mirror of https://github.com/cbeuw/Cloak
Refactor Transport and add tests
parent
e7e4cd5726
commit
74a70a3113
@ -0,0 +1,215 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/cbeuw/Cloak/internal/util"
|
||||
)
|
||||
|
||||
// ClientHello contains every field in a ClientHello message
|
||||
type ClientHello struct {
|
||||
handshakeType byte
|
||||
length int
|
||||
clientVersion []byte
|
||||
random []byte
|
||||
sessionIdLen int
|
||||
sessionId []byte
|
||||
cipherSuitesLen int
|
||||
cipherSuites []byte
|
||||
compressionMethodsLen int
|
||||
compressionMethods []byte
|
||||
extensionsLen int
|
||||
extensions map[[2]byte][]byte
|
||||
}
|
||||
|
||||
var u16 = binary.BigEndian.Uint16
|
||||
var u32 = binary.BigEndian.Uint32
|
||||
|
||||
func parseExtensions(input []byte) (ret map[[2]byte][]byte, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = errors.New("Malformed Extensions")
|
||||
}
|
||||
}()
|
||||
pointer := 0
|
||||
totalLen := len(input)
|
||||
ret = make(map[[2]byte][]byte)
|
||||
for pointer < totalLen {
|
||||
var typ [2]byte
|
||||
copy(typ[:], input[pointer:pointer+2])
|
||||
pointer += 2
|
||||
length := int(u16(input[pointer : pointer+2]))
|
||||
pointer += 2
|
||||
data := input[pointer : pointer+length]
|
||||
pointer += length
|
||||
ret[typ] = data
|
||||
}
|
||||
return ret, err
|
||||
}
|
||||
|
||||
func parseKeyShare(input []byte) (ret []byte, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = errors.New("malformed key_share")
|
||||
}
|
||||
}()
|
||||
totalLen := int(u16(input[0:2]))
|
||||
// 2 bytes "client key share length"
|
||||
pointer := 2
|
||||
for pointer < totalLen {
|
||||
if bytes.Equal([]byte{0x00, 0x1d}, input[pointer:pointer+2]) {
|
||||
// skip "key exchange length"
|
||||
pointer += 2
|
||||
length := int(u16(input[pointer : pointer+2]))
|
||||
pointer += 2
|
||||
if length != 32 {
|
||||
return nil, fmt.Errorf("key share length should be 32, instead of %v", length)
|
||||
}
|
||||
return input[pointer : pointer+length], nil
|
||||
}
|
||||
pointer += 2
|
||||
length := int(u16(input[pointer : pointer+2]))
|
||||
pointer += 2
|
||||
_ = input[pointer : pointer+length]
|
||||
pointer += length
|
||||
}
|
||||
return nil, errors.New("x25519 does not exist")
|
||||
}
|
||||
|
||||
// addRecordLayer adds record layer to data
|
||||
func addRecordLayer(input []byte, typ []byte, ver []byte) []byte {
|
||||
length := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(length, uint16(len(input)))
|
||||
ret := make([]byte, 5+len(input))
|
||||
copy(ret[0:1], typ)
|
||||
copy(ret[1:3], ver)
|
||||
copy(ret[3:5], length)
|
||||
copy(ret[5:], input)
|
||||
return ret
|
||||
}
|
||||
|
||||
// parseClientHello parses everything on top of the TLS layer
|
||||
// (including the record layer) into ClientHello type
|
||||
func parseClientHello(data []byte) (ret *ClientHello, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = errors.New("Malformed ClientHello")
|
||||
}
|
||||
}()
|
||||
|
||||
if !bytes.Equal(data[0:3], []byte{0x16, 0x03, 0x01}) {
|
||||
return ret, errors.New("wrong TLS1.3 handshake magic bytes")
|
||||
}
|
||||
|
||||
peeled := make([]byte, len(data)-5)
|
||||
copy(peeled, data[5:])
|
||||
pointer := 0
|
||||
// Handshake Type
|
||||
handshakeType := peeled[pointer]
|
||||
if handshakeType != 0x01 {
|
||||
return ret, errors.New("Not a ClientHello")
|
||||
}
|
||||
pointer += 1
|
||||
// Length
|
||||
length := int(u32(append([]byte{0x00}, peeled[pointer:pointer+3]...)))
|
||||
pointer += 3
|
||||
if length != len(peeled[pointer:]) {
|
||||
return ret, errors.New("Hello length doesn't match")
|
||||
}
|
||||
// Client Version
|
||||
clientVersion := peeled[pointer : pointer+2]
|
||||
pointer += 2
|
||||
// Random
|
||||
random := peeled[pointer : pointer+32]
|
||||
pointer += 32
|
||||
// Session ID
|
||||
sessionIdLen := int(peeled[pointer])
|
||||
pointer += 1
|
||||
sessionId := peeled[pointer : pointer+sessionIdLen]
|
||||
pointer += sessionIdLen
|
||||
// Cipher Suites
|
||||
cipherSuitesLen := int(u16(peeled[pointer : pointer+2]))
|
||||
pointer += 2
|
||||
cipherSuites := peeled[pointer : pointer+cipherSuitesLen]
|
||||
pointer += cipherSuitesLen
|
||||
// Compression Methods
|
||||
compressionMethodsLen := int(peeled[pointer])
|
||||
pointer += 1
|
||||
compressionMethods := peeled[pointer : pointer+compressionMethodsLen]
|
||||
pointer += compressionMethodsLen
|
||||
// Extensions
|
||||
extensionsLen := int(u16(peeled[pointer : pointer+2]))
|
||||
pointer += 2
|
||||
extensions, err := parseExtensions(peeled[pointer:])
|
||||
ret = &ClientHello{
|
||||
handshakeType,
|
||||
length,
|
||||
clientVersion,
|
||||
random,
|
||||
sessionIdLen,
|
||||
sessionId,
|
||||
cipherSuitesLen,
|
||||
cipherSuites,
|
||||
compressionMethodsLen,
|
||||
compressionMethods,
|
||||
extensionsLen,
|
||||
extensions,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func composeServerHello(sessionId []byte, sharedSecret []byte, sessionKey []byte) ([]byte, error) {
|
||||
nonce := make([]byte, 12)
|
||||
rand.Read(nonce)
|
||||
|
||||
encryptedKey, err := util.AESGCMEncrypt(nonce, sharedSecret, sessionKey) // 32 + 16 = 48 bytes
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var serverHello [11][]byte
|
||||
serverHello[0] = []byte{0x02} // handshake type
|
||||
serverHello[1] = []byte{0x00, 0x00, 0x76} // length 77
|
||||
serverHello[2] = []byte{0x03, 0x03} // server version
|
||||
serverHello[3] = append(nonce[0:12], encryptedKey[0:20]...) // random 32 bytes
|
||||
serverHello[4] = []byte{0x20} // session id length 32
|
||||
serverHello[5] = sessionId // session id
|
||||
serverHello[6] = []byte{0xc0, 0x30} // cipher suite TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384
|
||||
serverHello[7] = []byte{0x00} // compression method null
|
||||
serverHello[8] = []byte{0x00, 0x2e} // extensions length 46
|
||||
|
||||
keyShare, _ := hex.DecodeString("00330024001d0020")
|
||||
keyExchange := make([]byte, 32)
|
||||
copy(keyExchange, encryptedKey[20:48])
|
||||
rand.Read(keyExchange[28:32])
|
||||
serverHello[9] = append(keyShare, keyExchange...)
|
||||
|
||||
serverHello[10], _ = hex.DecodeString("002b00020304")
|
||||
var ret []byte
|
||||
for _, s := range serverHello {
|
||||
ret = append(ret, s...)
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// composeReply composes the ServerHello, ChangeCipherSpec and an ApplicationData messages
|
||||
// together with their respective record layers into one byte slice.
|
||||
func composeReply(ch *ClientHello, sharedSecret []byte, sessionKey []byte) ([]byte, error) {
|
||||
TLS12 := []byte{0x03, 0x03}
|
||||
sh, err := composeServerHello(ch.sessionId, sharedSecret, sessionKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
shBytes := addRecordLayer(sh, []byte{0x16}, TLS12)
|
||||
ccsBytes := addRecordLayer([]byte{0x01}, []byte{0x14}, TLS12)
|
||||
cert := make([]byte, 68) // TODO: add some different lengths maybe?
|
||||
rand.Read(cert)
|
||||
encryptedCertBytes := addRecordLayer(cert, []byte{0x17}, TLS12)
|
||||
ret := append(shBytes, ccsBytes...)
|
||||
ret = append(ret, encryptedCertBytes...)
|
||||
return ret, nil
|
||||
}
|
@ -1,23 +1,16 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/cbeuw/Cloak/internal/util"
|
||||
"crypto"
|
||||
"errors"
|
||||
"net"
|
||||
)
|
||||
|
||||
type Transport interface {
|
||||
HasRecordLayer() bool
|
||||
UnitReadFunc() func(net.Conn, []byte) (int, error)
|
||||
handshake(reqPacket []byte, privateKey crypto.PrivateKey, originalConn net.Conn) (authenticationInfo, func([]byte) (net.Conn, error), error)
|
||||
}
|
||||
|
||||
type TLS struct{}
|
||||
|
||||
func (TLS) String() string { return "TLS" }
|
||||
func (TLS) HasRecordLayer() bool { return true }
|
||||
func (TLS) UnitReadFunc() func(net.Conn, []byte) (int, error) { return util.ReadTLS }
|
||||
|
||||
type WebSocket struct{}
|
||||
|
||||
func (WebSocket) String() string { return "WebSocket" }
|
||||
func (WebSocket) HasRecordLayer() bool { return false }
|
||||
func (WebSocket) UnitReadFunc() func(net.Conn, []byte) (int, error) { return util.ReadWebSocket }
|
||||
var ErrInvalidPubKey = errors.New("public key has invalid format")
|
||||
var ErrCiphertextLength = errors.New("ciphertext has the wrong length")
|
||||
|
@ -0,0 +1,137 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/cbeuw/Cloak/internal/util"
|
||||
"github.com/gorilla/websocket"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// The code in this file is mostly to obtain a binary-oriented, net.Conn analogous
|
||||
// util.WebSocketConn from the awkward APIs of gorilla/websocket and net/http
|
||||
//
|
||||
// The flow of our process is: accept a Conn from remote, read the first packet remote sent us. If it's in the format
|
||||
// of a TLS handshake, we hand it over to the TLS part; if it's in the format of a HTTP request, we process it as a
|
||||
// websocket and eventually wrap the remote Conn as util.WebSocketConn,
|
||||
//
|
||||
// To get a util.WebSocketConn, we need a gorilla/websocket.Conn. This is obtained by using upgrader.Upgrade method
|
||||
// inside a HTTP request handler function (which is defined by us). The HTTP request handler function is invoked by
|
||||
// net/http package upon receiving a request from a Conn.
|
||||
//
|
||||
// Ideally we want to give net/http the connection we got from remote, then it can read the first packet (which should
|
||||
// be an HTTP request) from that Conn and call the handler function, which can then be upgraded to obtain a
|
||||
// gorilla/websocket.Conn. But this won't work for two reasons: one is that we have ALREADY READ the request packet
|
||||
// from the remote Conn to determine if it's TLS or HTTP. When net/http reads from the Conn, it will not receive that
|
||||
// request packet. The second reason is that there is no API in net/http that accepts a Conn at all. Instead, the
|
||||
// closest we can get is http.Serve which takes in a net.Listener and a http.Handler which implements the ServeHTTP
|
||||
// function.
|
||||
//
|
||||
// Recall that net.Listener has a method Accept which blocks until the Listener receives a connection, then
|
||||
// it returns a net.Conn. net/http calls Listener.Accept repeatedly and creates a new goroutine handling each Conn
|
||||
// accepted.
|
||||
//
|
||||
// So here is what we need to do: we need to create a type WsAcceptor that implements net.Listener interface.
|
||||
// the first time WsAcceptor.Accept is called, it will return something that implements net.Conn, subsequent calls to
|
||||
// Accept will return error (so that the caller won't call again)
|
||||
//
|
||||
// The "something that implements net.Conn" needs to do the following: the first time Read is called, it returns the
|
||||
// request packet we got from the remote Conn which we have already read, so that the packet, which is an HTTP request
|
||||
// will be processed by the handling function. Subsequent calls to Read will read directly from the remote Conn. To do
|
||||
// this we create a type firstBuffedConn that implements net.Conn. When we instantiate a firstBuffedConn object, we
|
||||
// give it the request packet we have already read from the remote Conn, as well as the reference to the remote Conn.
|
||||
//
|
||||
// So now we call http.Serve(WsAcceptor, [some handler]), net/http will call WsAcceptor.Accept, which returns a
|
||||
// firstBuffedConn. net/http will call WsAcceptor.Accept again but this time it returns error so net/http will stop.
|
||||
// firstBuffedConn.Read will then be called, which returns the request packet from remote Conn. Then
|
||||
// [some handler].ServeHTTP will be called, in which websocket.upgrader.Upgrade will be called to obtain a
|
||||
// websocket.Conn
|
||||
//
|
||||
// One problem remains: websocket.upgrader.Upgrade is called inside the handling function. The websocket.Conn it
|
||||
// returned needs to be somehow preserved so we can keep using it. To do this, we define a type WsHandshakeHandler
|
||||
// which implements http.Handler. WsHandshakeHandler has a struct field of type net.Conn that can be set. Inside
|
||||
// WsHandshakeHandler.ServeHTTP, the returned websocket.Conn from upgrader.Upgrade will be converted into a
|
||||
// util.WebSocketConn, whose reference will be kept in the struct field. Whoever has the reference to the instance of
|
||||
// WsHandshakeHandler can get the reference to the established util.WebSocketConn.
|
||||
//
|
||||
// There is another problem: the call of http.Serve(WsAcceptor, WsHandshakeHandler) is async. We don't know when
|
||||
// the instance of WsHandshakeHandler will have the util.WebSocketConn ready. We synchronise this using a channel.
|
||||
// A channel called finished will be provided to an instance of WsHandshakeHandler upon its creation. Once
|
||||
// WsHandshakeHandler.ServeHTTP has the reference to util.WebSocketConn ready, it will write to finished.
|
||||
// Outside, immediately after the call to http.Serve(WsAcceptor, WsHandshakeHandler), we read from finished so that the
|
||||
// execution will block until the reference to util.WebSocketConn is ready.
|
||||
|
||||
// since we need to read the first packet from the client to identify its protocol, the first packet will no longer
|
||||
// be in Conn's buffer. However, websocket.Upgrade relies on reading the first packet for handshake, so we must
|
||||
// fake a conn that returns the first packet on first read
|
||||
type firstBuffedConn struct {
|
||||
net.Conn
|
||||
firstRead bool
|
||||
firstPacket []byte
|
||||
}
|
||||
|
||||
func (c *firstBuffedConn) Read(buf []byte) (int, error) {
|
||||
if !c.firstRead {
|
||||
c.firstRead = true
|
||||
copy(buf, c.firstPacket)
|
||||
n := len(c.firstPacket)
|
||||
c.firstPacket = []byte{}
|
||||
return n, nil
|
||||
}
|
||||
return c.Conn.Read(buf)
|
||||
}
|
||||
|
||||
type wsAcceptor struct {
|
||||
done bool
|
||||
c *firstBuffedConn
|
||||
}
|
||||
|
||||
// net/http provides no method to serve an existing connection, we must feed in a net.Accept interface to get an
|
||||
// http.Server. This is an acceptor that accepts only one Conn
|
||||
func newWsAcceptor(conn net.Conn, first []byte) *wsAcceptor {
|
||||
f := make([]byte, len(first))
|
||||
copy(f, first)
|
||||
return &wsAcceptor{
|
||||
c: &firstBuffedConn{Conn: conn, firstPacket: f},
|
||||
}
|
||||
}
|
||||
|
||||
func (w *wsAcceptor) Accept() (net.Conn, error) {
|
||||
if w.done {
|
||||
return nil, errors.New("already accepted")
|
||||
}
|
||||
w.done = true
|
||||
return w.c, nil
|
||||
}
|
||||
|
||||
func (w *wsAcceptor) Close() error {
|
||||
w.done = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *wsAcceptor) Addr() net.Addr {
|
||||
return w.c.LocalAddr()
|
||||
}
|
||||
|
||||
type wsHandshakeHandler struct {
|
||||
conn net.Conn
|
||||
finished chan struct{}
|
||||
}
|
||||
|
||||
// the handler to turn a net.Conn into a websocket.Conn
|
||||
func newWsHandshakeHandler() *wsHandshakeHandler {
|
||||
return &wsHandshakeHandler{finished: make(chan struct{})}
|
||||
}
|
||||
|
||||
func (ws *wsHandshakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
upgrader := websocket.Upgrader{}
|
||||
c, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
log.Errorf("failed to upgrade connection to ws: %v", err)
|
||||
return
|
||||
}
|
||||
ws.conn = &util.WebSocketConn{Conn: c}
|
||||
ws.finished <- struct{}{}
|
||||
}
|
Loading…
Reference in New Issue