mirror of https://github.com/cbeuw/Cloak
Refactor Transport and add tests
This commit is contained in:
parent
e7e4cd5726
commit
74a70a3113
|
|
@ -1,221 +1,58 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"crypto"
|
"crypto"
|
||||||
"crypto/rand"
|
|
||||||
"encoding/binary"
|
|
||||||
"encoding/hex"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/cbeuw/Cloak/internal/ecdh"
|
"github.com/cbeuw/Cloak/internal/ecdh"
|
||||||
"github.com/cbeuw/Cloak/internal/util"
|
"github.com/cbeuw/Cloak/internal/util"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ClientHello contains every field in a ClientHello message
|
type TLS struct{}
|
||||||
type ClientHello struct {
|
|
||||||
handshakeType byte
|
|
||||||
length int
|
|
||||||
clientVersion []byte
|
|
||||||
random []byte
|
|
||||||
sessionIdLen int
|
|
||||||
sessionId []byte
|
|
||||||
cipherSuitesLen int
|
|
||||||
cipherSuites []byte
|
|
||||||
compressionMethodsLen int
|
|
||||||
compressionMethods []byte
|
|
||||||
extensionsLen int
|
|
||||||
extensions map[[2]byte][]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
var u16 = binary.BigEndian.Uint16
|
var ErrBadClientHello = errors.New("non (or malformed) ClientHello")
|
||||||
var u32 = binary.BigEndian.Uint32
|
|
||||||
|
|
||||||
func parseExtensions(input []byte) (ret map[[2]byte][]byte, err error) {
|
func (TLS) String() string { return "TLS" }
|
||||||
defer func() {
|
func (TLS) HasRecordLayer() bool { return true }
|
||||||
if r := recover(); r != nil {
|
func (TLS) UnitReadFunc() func(net.Conn, []byte) (int, error) { return util.ReadTLS }
|
||||||
err = errors.New("Malformed Extensions")
|
|
||||||
|
func (TLS) handshake(clientHello []byte, privateKey crypto.PrivateKey, originalConn net.Conn) (ai authenticationInfo, finisher func([]byte) (net.Conn, error), err error) {
|
||||||
|
var ch *ClientHello
|
||||||
|
ch, err = parseClientHello(clientHello)
|
||||||
|
if err != nil {
|
||||||
|
log.Debug(err)
|
||||||
|
err = ErrBadClientHello
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ai, err = unmarshalClientHello(ch, privateKey)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to unmarshal ClientHello into authenticationInfo: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
finisher = func(sessionKey []byte) (preparedConn net.Conn, err error) {
|
||||||
|
preparedConn = originalConn
|
||||||
|
reply, err := composeReply(ch, ai.sharedSecret, sessionKey)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to compose TLS reply: %v", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}()
|
_, err = preparedConn.Write(reply)
|
||||||
pointer := 0
|
if err != nil {
|
||||||
totalLen := len(input)
|
err = fmt.Errorf("failed to write TLS reply: %v", err)
|
||||||
ret = make(map[[2]byte][]byte)
|
go preparedConn.Close()
|
||||||
for pointer < totalLen {
|
return
|
||||||
var typ [2]byte
|
|
||||||
copy(typ[:], input[pointer:pointer+2])
|
|
||||||
pointer += 2
|
|
||||||
length := int(u16(input[pointer : pointer+2]))
|
|
||||||
pointer += 2
|
|
||||||
data := input[pointer : pointer+length]
|
|
||||||
pointer += length
|
|
||||||
ret[typ] = data
|
|
||||||
}
|
|
||||||
return ret, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseKeyShare(input []byte) (ret []byte, err error) {
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
err = errors.New("malformed key_share")
|
|
||||||
}
|
}
|
||||||
}()
|
return
|
||||||
totalLen := int(u16(input[0:2]))
|
|
||||||
// 2 bytes "client key share length"
|
|
||||||
pointer := 2
|
|
||||||
for pointer < totalLen {
|
|
||||||
if bytes.Equal([]byte{0x00, 0x1d}, input[pointer:pointer+2]) {
|
|
||||||
// skip "key exchange length"
|
|
||||||
pointer += 2
|
|
||||||
length := int(u16(input[pointer : pointer+2]))
|
|
||||||
pointer += 2
|
|
||||||
if length != 32 {
|
|
||||||
return nil, fmt.Errorf("key share length should be 32, instead of %v", length)
|
|
||||||
}
|
|
||||||
return input[pointer : pointer+length], nil
|
|
||||||
}
|
|
||||||
pointer += 2
|
|
||||||
length := int(u16(input[pointer : pointer+2]))
|
|
||||||
pointer += 2
|
|
||||||
_ = input[pointer : pointer+length]
|
|
||||||
pointer += length
|
|
||||||
}
|
|
||||||
return nil, errors.New("x25519 does not exist")
|
|
||||||
}
|
|
||||||
|
|
||||||
// addRecordLayer adds record layer to data
|
|
||||||
func addRecordLayer(input []byte, typ []byte, ver []byte) []byte {
|
|
||||||
length := make([]byte, 2)
|
|
||||||
binary.BigEndian.PutUint16(length, uint16(len(input)))
|
|
||||||
ret := make([]byte, 5+len(input))
|
|
||||||
copy(ret[0:1], typ)
|
|
||||||
copy(ret[1:3], ver)
|
|
||||||
copy(ret[3:5], length)
|
|
||||||
copy(ret[5:], input)
|
|
||||||
return ret
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseClientHello parses everything on top of the TLS layer
|
|
||||||
// (including the record layer) into ClientHello type
|
|
||||||
func parseClientHello(data []byte) (ret *ClientHello, err error) {
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
err = errors.New("Malformed ClientHello")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if !bytes.Equal(data[0:3], []byte{0x16, 0x03, 0x01}) {
|
|
||||||
return ret, errors.New("wrong TLS1.3 handshake magic bytes")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
peeled := make([]byte, len(data)-5)
|
|
||||||
copy(peeled, data[5:])
|
|
||||||
pointer := 0
|
|
||||||
// Handshake Type
|
|
||||||
handshakeType := peeled[pointer]
|
|
||||||
if handshakeType != 0x01 {
|
|
||||||
return ret, errors.New("Not a ClientHello")
|
|
||||||
}
|
|
||||||
pointer += 1
|
|
||||||
// Length
|
|
||||||
length := int(u32(append([]byte{0x00}, peeled[pointer:pointer+3]...)))
|
|
||||||
pointer += 3
|
|
||||||
if length != len(peeled[pointer:]) {
|
|
||||||
return ret, errors.New("Hello length doesn't match")
|
|
||||||
}
|
|
||||||
// Client Version
|
|
||||||
clientVersion := peeled[pointer : pointer+2]
|
|
||||||
pointer += 2
|
|
||||||
// Random
|
|
||||||
random := peeled[pointer : pointer+32]
|
|
||||||
pointer += 32
|
|
||||||
// Session ID
|
|
||||||
sessionIdLen := int(peeled[pointer])
|
|
||||||
pointer += 1
|
|
||||||
sessionId := peeled[pointer : pointer+sessionIdLen]
|
|
||||||
pointer += sessionIdLen
|
|
||||||
// Cipher Suites
|
|
||||||
cipherSuitesLen := int(u16(peeled[pointer : pointer+2]))
|
|
||||||
pointer += 2
|
|
||||||
cipherSuites := peeled[pointer : pointer+cipherSuitesLen]
|
|
||||||
pointer += cipherSuitesLen
|
|
||||||
// Compression Methods
|
|
||||||
compressionMethodsLen := int(peeled[pointer])
|
|
||||||
pointer += 1
|
|
||||||
compressionMethods := peeled[pointer : pointer+compressionMethodsLen]
|
|
||||||
pointer += compressionMethodsLen
|
|
||||||
// Extensions
|
|
||||||
extensionsLen := int(u16(peeled[pointer : pointer+2]))
|
|
||||||
pointer += 2
|
|
||||||
extensions, err := parseExtensions(peeled[pointer:])
|
|
||||||
ret = &ClientHello{
|
|
||||||
handshakeType,
|
|
||||||
length,
|
|
||||||
clientVersion,
|
|
||||||
random,
|
|
||||||
sessionIdLen,
|
|
||||||
sessionId,
|
|
||||||
cipherSuitesLen,
|
|
||||||
cipherSuites,
|
|
||||||
compressionMethodsLen,
|
|
||||||
compressionMethods,
|
|
||||||
extensionsLen,
|
|
||||||
extensions,
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func composeServerHello(sessionId []byte, sharedSecret []byte, sessionKey []byte) ([]byte, error) {
|
|
||||||
nonce := make([]byte, 12)
|
|
||||||
rand.Read(nonce)
|
|
||||||
|
|
||||||
encryptedKey, err := util.AESGCMEncrypt(nonce, sharedSecret, sessionKey) // 32 + 16 = 48 bytes
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var serverHello [11][]byte
|
|
||||||
serverHello[0] = []byte{0x02} // handshake type
|
|
||||||
serverHello[1] = []byte{0x00, 0x00, 0x76} // length 77
|
|
||||||
serverHello[2] = []byte{0x03, 0x03} // server version
|
|
||||||
serverHello[3] = append(nonce[0:12], encryptedKey[0:20]...) // random 32 bytes
|
|
||||||
serverHello[4] = []byte{0x20} // session id length 32
|
|
||||||
serverHello[5] = sessionId // session id
|
|
||||||
serverHello[6] = []byte{0xc0, 0x30} // cipher suite TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384
|
|
||||||
serverHello[7] = []byte{0x00} // compression method null
|
|
||||||
serverHello[8] = []byte{0x00, 0x2e} // extensions length 46
|
|
||||||
|
|
||||||
keyShare, _ := hex.DecodeString("00330024001d0020")
|
|
||||||
keyExchange := make([]byte, 32)
|
|
||||||
copy(keyExchange, encryptedKey[20:48])
|
|
||||||
rand.Read(keyExchange[28:32])
|
|
||||||
serverHello[9] = append(keyShare, keyExchange...)
|
|
||||||
|
|
||||||
serverHello[10], _ = hex.DecodeString("002b00020304")
|
|
||||||
var ret []byte
|
|
||||||
for _, s := range serverHello {
|
|
||||||
ret = append(ret, s...)
|
|
||||||
}
|
|
||||||
return ret, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// composeReply composes the ServerHello, ChangeCipherSpec and an ApplicationData messages
|
|
||||||
// together with their respective record layers into one byte slice.
|
|
||||||
func composeReply(ch *ClientHello, sharedSecret []byte, sessionKey []byte) ([]byte, error) {
|
|
||||||
TLS12 := []byte{0x03, 0x03}
|
|
||||||
sh, err := composeServerHello(ch.sessionId, sharedSecret, sessionKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
shBytes := addRecordLayer(sh, []byte{0x16}, TLS12)
|
|
||||||
ccsBytes := addRecordLayer([]byte{0x01}, []byte{0x14}, TLS12)
|
|
||||||
cert := make([]byte, 68) // TODO: add some different lengths maybe?
|
|
||||||
rand.Read(cert)
|
|
||||||
encryptedCertBytes := addRecordLayer(cert, []byte{0x17}, TLS12)
|
|
||||||
ret := append(shBytes, ccsBytes...)
|
|
||||||
ret = append(ret, encryptedCertBytes...)
|
|
||||||
return ret, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func unmarshalClientHello(ch *ClientHello, staticPv crypto.PrivateKey) (ai authenticationInfo, err error) {
|
func unmarshalClientHello(ch *ClientHello, staticPv crypto.PrivateKey) (ai authenticationInfo, err error) {
|
||||||
ephPub, ok := ecdh.Unmarshal(ch.random)
|
ephPub, ok := ecdh.Unmarshal(ch.random)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,215 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"github.com/cbeuw/Cloak/internal/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ClientHello contains every field in a ClientHello message
|
||||||
|
type ClientHello struct {
|
||||||
|
handshakeType byte
|
||||||
|
length int
|
||||||
|
clientVersion []byte
|
||||||
|
random []byte
|
||||||
|
sessionIdLen int
|
||||||
|
sessionId []byte
|
||||||
|
cipherSuitesLen int
|
||||||
|
cipherSuites []byte
|
||||||
|
compressionMethodsLen int
|
||||||
|
compressionMethods []byte
|
||||||
|
extensionsLen int
|
||||||
|
extensions map[[2]byte][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
var u16 = binary.BigEndian.Uint16
|
||||||
|
var u32 = binary.BigEndian.Uint32
|
||||||
|
|
||||||
|
func parseExtensions(input []byte) (ret map[[2]byte][]byte, err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = errors.New("Malformed Extensions")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
pointer := 0
|
||||||
|
totalLen := len(input)
|
||||||
|
ret = make(map[[2]byte][]byte)
|
||||||
|
for pointer < totalLen {
|
||||||
|
var typ [2]byte
|
||||||
|
copy(typ[:], input[pointer:pointer+2])
|
||||||
|
pointer += 2
|
||||||
|
length := int(u16(input[pointer : pointer+2]))
|
||||||
|
pointer += 2
|
||||||
|
data := input[pointer : pointer+length]
|
||||||
|
pointer += length
|
||||||
|
ret[typ] = data
|
||||||
|
}
|
||||||
|
return ret, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseKeyShare(input []byte) (ret []byte, err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = errors.New("malformed key_share")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
totalLen := int(u16(input[0:2]))
|
||||||
|
// 2 bytes "client key share length"
|
||||||
|
pointer := 2
|
||||||
|
for pointer < totalLen {
|
||||||
|
if bytes.Equal([]byte{0x00, 0x1d}, input[pointer:pointer+2]) {
|
||||||
|
// skip "key exchange length"
|
||||||
|
pointer += 2
|
||||||
|
length := int(u16(input[pointer : pointer+2]))
|
||||||
|
pointer += 2
|
||||||
|
if length != 32 {
|
||||||
|
return nil, fmt.Errorf("key share length should be 32, instead of %v", length)
|
||||||
|
}
|
||||||
|
return input[pointer : pointer+length], nil
|
||||||
|
}
|
||||||
|
pointer += 2
|
||||||
|
length := int(u16(input[pointer : pointer+2]))
|
||||||
|
pointer += 2
|
||||||
|
_ = input[pointer : pointer+length]
|
||||||
|
pointer += length
|
||||||
|
}
|
||||||
|
return nil, errors.New("x25519 does not exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
// addRecordLayer adds record layer to data
|
||||||
|
func addRecordLayer(input []byte, typ []byte, ver []byte) []byte {
|
||||||
|
length := make([]byte, 2)
|
||||||
|
binary.BigEndian.PutUint16(length, uint16(len(input)))
|
||||||
|
ret := make([]byte, 5+len(input))
|
||||||
|
copy(ret[0:1], typ)
|
||||||
|
copy(ret[1:3], ver)
|
||||||
|
copy(ret[3:5], length)
|
||||||
|
copy(ret[5:], input)
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseClientHello parses everything on top of the TLS layer
|
||||||
|
// (including the record layer) into ClientHello type
|
||||||
|
func parseClientHello(data []byte) (ret *ClientHello, err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = errors.New("Malformed ClientHello")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if !bytes.Equal(data[0:3], []byte{0x16, 0x03, 0x01}) {
|
||||||
|
return ret, errors.New("wrong TLS1.3 handshake magic bytes")
|
||||||
|
}
|
||||||
|
|
||||||
|
peeled := make([]byte, len(data)-5)
|
||||||
|
copy(peeled, data[5:])
|
||||||
|
pointer := 0
|
||||||
|
// Handshake Type
|
||||||
|
handshakeType := peeled[pointer]
|
||||||
|
if handshakeType != 0x01 {
|
||||||
|
return ret, errors.New("Not a ClientHello")
|
||||||
|
}
|
||||||
|
pointer += 1
|
||||||
|
// Length
|
||||||
|
length := int(u32(append([]byte{0x00}, peeled[pointer:pointer+3]...)))
|
||||||
|
pointer += 3
|
||||||
|
if length != len(peeled[pointer:]) {
|
||||||
|
return ret, errors.New("Hello length doesn't match")
|
||||||
|
}
|
||||||
|
// Client Version
|
||||||
|
clientVersion := peeled[pointer : pointer+2]
|
||||||
|
pointer += 2
|
||||||
|
// Random
|
||||||
|
random := peeled[pointer : pointer+32]
|
||||||
|
pointer += 32
|
||||||
|
// Session ID
|
||||||
|
sessionIdLen := int(peeled[pointer])
|
||||||
|
pointer += 1
|
||||||
|
sessionId := peeled[pointer : pointer+sessionIdLen]
|
||||||
|
pointer += sessionIdLen
|
||||||
|
// Cipher Suites
|
||||||
|
cipherSuitesLen := int(u16(peeled[pointer : pointer+2]))
|
||||||
|
pointer += 2
|
||||||
|
cipherSuites := peeled[pointer : pointer+cipherSuitesLen]
|
||||||
|
pointer += cipherSuitesLen
|
||||||
|
// Compression Methods
|
||||||
|
compressionMethodsLen := int(peeled[pointer])
|
||||||
|
pointer += 1
|
||||||
|
compressionMethods := peeled[pointer : pointer+compressionMethodsLen]
|
||||||
|
pointer += compressionMethodsLen
|
||||||
|
// Extensions
|
||||||
|
extensionsLen := int(u16(peeled[pointer : pointer+2]))
|
||||||
|
pointer += 2
|
||||||
|
extensions, err := parseExtensions(peeled[pointer:])
|
||||||
|
ret = &ClientHello{
|
||||||
|
handshakeType,
|
||||||
|
length,
|
||||||
|
clientVersion,
|
||||||
|
random,
|
||||||
|
sessionIdLen,
|
||||||
|
sessionId,
|
||||||
|
cipherSuitesLen,
|
||||||
|
cipherSuites,
|
||||||
|
compressionMethodsLen,
|
||||||
|
compressionMethods,
|
||||||
|
extensionsLen,
|
||||||
|
extensions,
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func composeServerHello(sessionId []byte, sharedSecret []byte, sessionKey []byte) ([]byte, error) {
|
||||||
|
nonce := make([]byte, 12)
|
||||||
|
rand.Read(nonce)
|
||||||
|
|
||||||
|
encryptedKey, err := util.AESGCMEncrypt(nonce, sharedSecret, sessionKey) // 32 + 16 = 48 bytes
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var serverHello [11][]byte
|
||||||
|
serverHello[0] = []byte{0x02} // handshake type
|
||||||
|
serverHello[1] = []byte{0x00, 0x00, 0x76} // length 77
|
||||||
|
serverHello[2] = []byte{0x03, 0x03} // server version
|
||||||
|
serverHello[3] = append(nonce[0:12], encryptedKey[0:20]...) // random 32 bytes
|
||||||
|
serverHello[4] = []byte{0x20} // session id length 32
|
||||||
|
serverHello[5] = sessionId // session id
|
||||||
|
serverHello[6] = []byte{0xc0, 0x30} // cipher suite TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384
|
||||||
|
serverHello[7] = []byte{0x00} // compression method null
|
||||||
|
serverHello[8] = []byte{0x00, 0x2e} // extensions length 46
|
||||||
|
|
||||||
|
keyShare, _ := hex.DecodeString("00330024001d0020")
|
||||||
|
keyExchange := make([]byte, 32)
|
||||||
|
copy(keyExchange, encryptedKey[20:48])
|
||||||
|
rand.Read(keyExchange[28:32])
|
||||||
|
serverHello[9] = append(keyShare, keyExchange...)
|
||||||
|
|
||||||
|
serverHello[10], _ = hex.DecodeString("002b00020304")
|
||||||
|
var ret []byte
|
||||||
|
for _, s := range serverHello {
|
||||||
|
ret = append(ret, s...)
|
||||||
|
}
|
||||||
|
return ret, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// composeReply composes the ServerHello, ChangeCipherSpec and an ApplicationData messages
|
||||||
|
// together with their respective record layers into one byte slice.
|
||||||
|
func composeReply(ch *ClientHello, sharedSecret []byte, sessionKey []byte) ([]byte, error) {
|
||||||
|
TLS12 := []byte{0x03, 0x03}
|
||||||
|
sh, err := composeServerHello(ch.sessionId, sharedSecret, sessionKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
shBytes := addRecordLayer(sh, []byte{0x16}, TLS12)
|
||||||
|
ccsBytes := addRecordLayer([]byte{0x01}, []byte{0x14}, TLS12)
|
||||||
|
cert := make([]byte, 68) // TODO: add some different lengths maybe?
|
||||||
|
rand.Read(cert)
|
||||||
|
encryptedCertBytes := addRecordLayer(cert, []byte{0x17}, TLS12)
|
||||||
|
ret := append(shBytes, ccsBytes...)
|
||||||
|
ret = append(ret, encryptedCertBytes...)
|
||||||
|
return ret, nil
|
||||||
|
}
|
||||||
|
|
@ -1,16 +1,12 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/rand"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/cbeuw/Cloak/internal/util"
|
"github.com/cbeuw/Cloak/internal/util"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
@ -35,8 +31,6 @@ const (
|
||||||
UNORDERED_FLAG = 0x01 // 0000 0001
|
UNORDERED_FLAG = 0x01 // 0000 0001
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrInvalidPubKey = errors.New("public key has invalid format")
|
|
||||||
var ErrCiphertextLength = errors.New("ciphertext has the wrong length")
|
|
||||||
var ErrTimestampOutOfWindow = errors.New("timestamp is outside of the accepting window")
|
var ErrTimestampOutOfWindow = errors.New("timestamp is outside of the accepting window")
|
||||||
var ErrUnreconisedProtocol = errors.New("unreconised protocol")
|
var ErrUnreconisedProtocol = errors.New("unreconised protocol")
|
||||||
|
|
||||||
|
|
@ -67,7 +61,6 @@ func touchStone(ai authenticationInfo, now func() time.Time) (info ClientInfo, e
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrBadClientHello = errors.New("non (or malformed) ClientHello")
|
|
||||||
var ErrReplay = errors.New("duplicate random")
|
var ErrReplay = errors.New("duplicate random")
|
||||||
var ErrBadProxyMethod = errors.New("invalid proxy method")
|
var ErrBadProxyMethod = errors.New("invalid proxy method")
|
||||||
|
|
||||||
|
|
@ -76,100 +69,34 @@ var ErrBadProxyMethod = errors.New("invalid proxy method")
|
||||||
// is authorised. It also returns a finisher callback function to be called when the caller wishes to proceed with
|
// is authorised. It also returns a finisher callback function to be called when the caller wishes to proceed with
|
||||||
// the handshake
|
// the handshake
|
||||||
func PrepareConnection(firstPacket []byte, sta *State, conn net.Conn) (info ClientInfo, finisher func([]byte) (net.Conn, error), err error) {
|
func PrepareConnection(firstPacket []byte, sta *State, conn net.Conn) (info ClientInfo, finisher func([]byte) (net.Conn, error), err error) {
|
||||||
var transport Transport
|
|
||||||
var ai authenticationInfo
|
|
||||||
switch firstPacket[0] {
|
switch firstPacket[0] {
|
||||||
case 0x47:
|
case 0x47:
|
||||||
transport = WebSocket{}
|
info.Transport = WebSocket{}
|
||||||
var req *http.Request
|
|
||||||
req, err = http.ReadRequest(bufio.NewReader(bytes.NewBuffer(firstPacket)))
|
|
||||||
if err != nil {
|
|
||||||
err = fmt.Errorf("failed to parse first HTTP GET: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var hiddenData []byte
|
|
||||||
hiddenData, err = base64.StdEncoding.DecodeString(req.Header.Get("hidden"))
|
|
||||||
|
|
||||||
ai, err = unmarshalHidden(hiddenData, sta.staticPv)
|
|
||||||
if err != nil {
|
|
||||||
err = fmt.Errorf("failed to unmarshal hidden data from WS into authenticationInfo: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
finisher = func(sessionKey []byte) (preparedConn net.Conn, err error) {
|
|
||||||
handler := newWsHandshakeHandler()
|
|
||||||
|
|
||||||
// For an explanation of the following 3 lines, see the comments in websocket.go
|
|
||||||
http.Serve(newWsAcceptor(conn, firstPacket), handler)
|
|
||||||
|
|
||||||
<-handler.finished
|
|
||||||
preparedConn = handler.conn
|
|
||||||
nonce := make([]byte, 12)
|
|
||||||
rand.Read(nonce)
|
|
||||||
|
|
||||||
// reply: [12 bytes nonce][32 bytes encrypted session key][16 bytes authentication tag]
|
|
||||||
encryptedKey, err := util.AESGCMEncrypt(nonce, ai.sharedSecret, sessionKey) // 32 + 16 = 48 bytes
|
|
||||||
if err != nil {
|
|
||||||
err = fmt.Errorf("failed to encrypt reply: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
reply := append(nonce, encryptedKey...)
|
|
||||||
_, err = preparedConn.Write(reply)
|
|
||||||
if err != nil {
|
|
||||||
err = fmt.Errorf("failed to write reply: %v", err)
|
|
||||||
go preparedConn.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
case 0x16:
|
case 0x16:
|
||||||
transport = TLS{}
|
info.Transport = TLS{}
|
||||||
var ch *ClientHello
|
|
||||||
ch, err = parseClientHello(firstPacket)
|
|
||||||
if err != nil {
|
|
||||||
log.Debug(err)
|
|
||||||
err = ErrBadClientHello
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if sta.registerRandom(ch.random) {
|
|
||||||
err = ErrReplay
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ai, err = unmarshalClientHello(ch, sta.staticPv)
|
|
||||||
if err != nil {
|
|
||||||
err = fmt.Errorf("failed to unmarshal ClientHello into authenticationInfo: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
finisher = func(sessionKey []byte) (preparedConn net.Conn, err error) {
|
|
||||||
preparedConn = conn
|
|
||||||
reply, err := composeReply(ch, ai.sharedSecret, sessionKey)
|
|
||||||
if err != nil {
|
|
||||||
err = fmt.Errorf("failed to compose TLS reply: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_, err = preparedConn.Write(reply)
|
|
||||||
if err != nil {
|
|
||||||
err = fmt.Errorf("failed to write TLS reply: %v", err)
|
|
||||||
go preparedConn.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
default:
|
default:
|
||||||
err = ErrUnreconisedProtocol
|
err = ErrUnreconisedProtocol
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ai authenticationInfo
|
||||||
|
ai, finisher, err = info.Transport.handshake(firstPacket, sta.staticPv, conn)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if sta.registerRandom(ai.nonce) {
|
||||||
|
err = ErrReplay
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
info, err = touchStone(ai, sta.Now)
|
info, err = touchStone(ai, sta.Now)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debug(err)
|
log.Debug(err)
|
||||||
err = fmt.Errorf("transport %v in correct format but not Cloak: %v", transport, err)
|
err = fmt.Errorf("transport %v in correct format but not Cloak: %v", info.Transport, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
info.Transport = transport
|
|
||||||
if _, ok := sta.ProxyBook[info.ProxyMethod]; !ok {
|
if _, ok := sta.ProxyBook[info.ProxyMethod]; !ok {
|
||||||
err = ErrBadProxyMethod
|
err = ErrBadProxyMethod
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -98,3 +98,38 @@ func TestTouchStone(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPrepareConnection(t *testing.T) {
|
||||||
|
nineSixSix := func() time.Time { return time.Unix(1565998966, 0) }
|
||||||
|
sta, _ := InitState(nineSixSix)
|
||||||
|
pvBytes, _ := hex.DecodeString("10de5a3c4a4d04efafc3e06d1506363a72bd6d053baef123e6a9a79a0c04b547")
|
||||||
|
p, _ := ecdh.Unmarshal(pvBytes)
|
||||||
|
sta.staticPv = p.(crypto.PrivateKey)
|
||||||
|
sta.ProxyBook["shadowsocks"] = nil
|
||||||
|
|
||||||
|
t.Run("TLS correct", func(t *testing.T) {
|
||||||
|
chBytes, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
||||||
|
info, _, err := PrepareConnection(chBytes, sta, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to get client info: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if info.SessionId != 3710878841 {
|
||||||
|
t.Error("failed to get correct session id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("TLS correct but replay", func(t *testing.T) {
|
||||||
|
chBytes, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
||||||
|
_, _, err := PrepareConnection(chBytes, sta, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Error("failed to prepare for the first time")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, _, err = PrepareConnection(chBytes, sta, nil)
|
||||||
|
if err != ErrReplay {
|
||||||
|
t.Errorf("failed to return ErrReplay, got %v instead", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,23 +1,16 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/cbeuw/Cloak/internal/util"
|
"crypto"
|
||||||
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Transport interface {
|
type Transport interface {
|
||||||
HasRecordLayer() bool
|
HasRecordLayer() bool
|
||||||
UnitReadFunc() func(net.Conn, []byte) (int, error)
|
UnitReadFunc() func(net.Conn, []byte) (int, error)
|
||||||
|
handshake(reqPacket []byte, privateKey crypto.PrivateKey, originalConn net.Conn) (authenticationInfo, func([]byte) (net.Conn, error), error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type TLS struct{}
|
var ErrInvalidPubKey = errors.New("public key has invalid format")
|
||||||
|
var ErrCiphertextLength = errors.New("ciphertext has the wrong length")
|
||||||
func (TLS) String() string { return "TLS" }
|
|
||||||
func (TLS) HasRecordLayer() bool { return true }
|
|
||||||
func (TLS) UnitReadFunc() func(net.Conn, []byte) (int, error) { return util.ReadTLS }
|
|
||||||
|
|
||||||
type WebSocket struct{}
|
|
||||||
|
|
||||||
func (WebSocket) String() string { return "WebSocket" }
|
|
||||||
func (WebSocket) HasRecordLayer() bool { return false }
|
|
||||||
func (WebSocket) UnitReadFunc() func(net.Conn, []byte) (int, error) { return util.ReadWebSocket }
|
|
||||||
|
|
|
||||||
|
|
@ -1,142 +1,69 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
"crypto"
|
"crypto"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/cbeuw/Cloak/internal/ecdh"
|
"github.com/cbeuw/Cloak/internal/ecdh"
|
||||||
"github.com/cbeuw/Cloak/internal/util"
|
"github.com/cbeuw/Cloak/internal/util"
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// The code in this file is mostly to obtain a binary-oriented, net.Conn analogous
|
type WebSocket struct{}
|
||||||
// util.WebSocketConn from the awkward APIs of gorilla/websocket and net/http
|
|
||||||
//
|
|
||||||
// The flow of our process is: accept a Conn from remote, read the first packet remote sent us. If it's in the format
|
|
||||||
// of a TLS handshake, we hand it over to the TLS part; if it's in the format of a HTTP request, we process it as a
|
|
||||||
// websocket and eventually wrap the remote Conn as util.WebSocketConn,
|
|
||||||
//
|
|
||||||
// To get a util.WebSocketConn, we need a gorilla/websocket.Conn. This is obtained by using upgrader.Upgrade method
|
|
||||||
// inside a HTTP request handler function (which is defined by us). The HTTP request handler function is invoked by
|
|
||||||
// net/http package upon receiving a request from a Conn.
|
|
||||||
//
|
|
||||||
// Ideally we want to give net/http the connection we got from remote, then it can read the first packet (which should
|
|
||||||
// be an HTTP request) from that Conn and call the handler function, which can then be upgraded to obtain a
|
|
||||||
// gorilla/websocket.Conn. But this won't work for two reasons: one is that we have ALREADY READ the request packet
|
|
||||||
// from the remote Conn to determine if it's TLS or HTTP. When net/http reads from the Conn, it will not receive that
|
|
||||||
// request packet. The second reason is that there is no API in net/http that accepts a Conn at all. Instead, the
|
|
||||||
// closest we can get is http.Serve which takes in a net.Listener and a http.Handler which implements the ServeHTTP
|
|
||||||
// function.
|
|
||||||
//
|
|
||||||
// Recall that net.Listener has a method Accept which blocks until the Listener receives a connection, then
|
|
||||||
// it returns a net.Conn. net/http calls Listener.Accept repeatedly and creates a new goroutine handling each Conn
|
|
||||||
// accepted.
|
|
||||||
//
|
|
||||||
// So here is what we need to do: we need to create a type WsAcceptor that implements net.Listener interface.
|
|
||||||
// the first time WsAcceptor.Accept is called, it will return something that implements net.Conn, subsequent calls to
|
|
||||||
// Accept will return error (so that the caller won't call again)
|
|
||||||
//
|
|
||||||
// The "something that implements net.Conn" needs to do the following: the first time Read is called, it returns the
|
|
||||||
// request packet we got from the remote Conn which we have already read, so that the packet, which is an HTTP request
|
|
||||||
// will be processed by the handling function. Subsequent calls to Read will read directly from the remote Conn. To do
|
|
||||||
// this we create a type firstBuffedConn that implements net.Conn. When we instantiate a firstBuffedConn object, we
|
|
||||||
// give it the request packet we have already read from the remote Conn, as well as the reference to the remote Conn.
|
|
||||||
//
|
|
||||||
// So now we call http.Serve(WsAcceptor, [some handler]), net/http will call WsAcceptor.Accept, which returns a
|
|
||||||
// firstBuffedConn. net/http will call WsAcceptor.Accept again but this time it returns error so net/http will stop.
|
|
||||||
// firstBuffedConn.Read will then be called, which returns the request packet from remote Conn. Then
|
|
||||||
// [some handler].ServeHTTP will be called, in which websocket.upgrader.Upgrade will be called to obtain a
|
|
||||||
// websocket.Conn
|
|
||||||
//
|
|
||||||
// One problem remains: websocket.upgrader.Upgrade is called inside the handling function. The websocket.Conn it
|
|
||||||
// returned needs to be somehow preserved so we can keep using it. To do this, we define a type WsHandshakeHandler
|
|
||||||
// which implements http.Handler. WsHandshakeHandler has a struct field of type net.Conn that can be set. Inside
|
|
||||||
// WsHandshakeHandler.ServeHTTP, the returned websocket.Conn from upgrader.Upgrade will be converted into a
|
|
||||||
// util.WebSocketConn, whose reference will be kept in the struct field. Whoever has the reference to the instance of
|
|
||||||
// WsHandshakeHandler can get the reference to the established util.WebSocketConn.
|
|
||||||
//
|
|
||||||
// There is another problem: the call of http.Serve(WsAcceptor, WsHandshakeHandler) is async. We don't know when
|
|
||||||
// the instance of WsHandshakeHandler will have the util.WebSocketConn ready. We synchronise this using a channel.
|
|
||||||
// A channel called finished will be provided to an instance of WsHandshakeHandler upon its creation. Once
|
|
||||||
// WsHandshakeHandler.ServeHTTP has the reference to util.WebSocketConn ready, it will write to finished.
|
|
||||||
// Outside, immediately after the call to http.Serve(WsAcceptor, WsHandshakeHandler), we read from finished so that the
|
|
||||||
// execution will block until the reference to util.WebSocketConn is ready.
|
|
||||||
|
|
||||||
// since we need to read the first packet from the client to identify its protocol, the first packet will no longer
|
func (WebSocket) String() string { return "WebSocket" }
|
||||||
// be in Conn's buffer. However, websocket.Upgrade relies on reading the first packet for handshake, so we must
|
func (WebSocket) HasRecordLayer() bool { return false }
|
||||||
// fake a conn that returns the first packet on first read
|
func (WebSocket) UnitReadFunc() func(net.Conn, []byte) (int, error) { return util.ReadWebSocket }
|
||||||
type firstBuffedConn struct {
|
|
||||||
net.Conn
|
|
||||||
firstRead bool
|
|
||||||
firstPacket []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *firstBuffedConn) Read(buf []byte) (int, error) {
|
func (WebSocket) handshake(reqPacket []byte, privateKey crypto.PrivateKey, originalConn net.Conn) (ai authenticationInfo, finisher func([]byte) (net.Conn, error), err error) {
|
||||||
if !c.firstRead {
|
var req *http.Request
|
||||||
c.firstRead = true
|
req, err = http.ReadRequest(bufio.NewReader(bytes.NewBuffer(reqPacket)))
|
||||||
copy(buf, c.firstPacket)
|
|
||||||
n := len(c.firstPacket)
|
|
||||||
c.firstPacket = []byte{}
|
|
||||||
return n, nil
|
|
||||||
}
|
|
||||||
return c.Conn.Read(buf)
|
|
||||||
}
|
|
||||||
|
|
||||||
type wsAcceptor struct {
|
|
||||||
done bool
|
|
||||||
c *firstBuffedConn
|
|
||||||
}
|
|
||||||
|
|
||||||
// net/http provides no method to serve an existing connection, we must feed in a net.Accept interface to get an
|
|
||||||
// http.Server. This is an acceptor that accepts only one Conn
|
|
||||||
func newWsAcceptor(conn net.Conn, first []byte) *wsAcceptor {
|
|
||||||
f := make([]byte, len(first))
|
|
||||||
copy(f, first)
|
|
||||||
return &wsAcceptor{
|
|
||||||
c: &firstBuffedConn{Conn: conn, firstPacket: f},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *wsAcceptor) Accept() (net.Conn, error) {
|
|
||||||
if w.done {
|
|
||||||
return nil, errors.New("already accepted")
|
|
||||||
}
|
|
||||||
w.done = true
|
|
||||||
return w.c, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *wsAcceptor) Close() error {
|
|
||||||
w.done = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *wsAcceptor) Addr() net.Addr {
|
|
||||||
return w.c.LocalAddr()
|
|
||||||
}
|
|
||||||
|
|
||||||
type wsHandshakeHandler struct {
|
|
||||||
conn net.Conn
|
|
||||||
finished chan struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// the handler to turn a net.Conn into a websocket.Conn
|
|
||||||
func newWsHandshakeHandler() *wsHandshakeHandler {
|
|
||||||
return &wsHandshakeHandler{finished: make(chan struct{})}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ws *wsHandshakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
||||||
upgrader := websocket.Upgrader{}
|
|
||||||
c, err := upgrader.Upgrade(w, r, nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to upgrade connection to ws: %v", err)
|
err = fmt.Errorf("failed to parse first HTTP GET: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ws.conn = &util.WebSocketConn{Conn: c}
|
var hiddenData []byte
|
||||||
ws.finished <- struct{}{}
|
hiddenData, err = base64.StdEncoding.DecodeString(req.Header.Get("hidden"))
|
||||||
|
|
||||||
|
ai, err = unmarshalHidden(hiddenData, privateKey)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to unmarshal hidden data from WS into authenticationInfo: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
finisher = func(sessionKey []byte) (preparedConn net.Conn, err error) {
|
||||||
|
handler := newWsHandshakeHandler()
|
||||||
|
|
||||||
|
// For an explanation of the following 3 lines, see the comments in websocketAux.go
|
||||||
|
http.Serve(newWsAcceptor(originalConn, reqPacket), handler)
|
||||||
|
|
||||||
|
<-handler.finished
|
||||||
|
preparedConn = handler.conn
|
||||||
|
nonce := make([]byte, 12)
|
||||||
|
rand.Read(nonce)
|
||||||
|
|
||||||
|
// reply: [12 bytes nonce][32 bytes encrypted session key][16 bytes authentication tag]
|
||||||
|
encryptedKey, err := util.AESGCMEncrypt(nonce, ai.sharedSecret, sessionKey) // 32 + 16 = 48 bytes
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to encrypt reply: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
reply := append(nonce, encryptedKey...)
|
||||||
|
_, err = preparedConn.Write(reply)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to write reply: %v", err)
|
||||||
|
go preparedConn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrBadGET = errors.New("non (or malformed) HTTP GET")
|
var ErrBadGET = errors.New("non (or malformed) HTTP GET")
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,137 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"github.com/cbeuw/Cloak/internal/util"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// The code in this file is mostly to obtain a binary-oriented, net.Conn analogous
|
||||||
|
// util.WebSocketConn from the awkward APIs of gorilla/websocket and net/http
|
||||||
|
//
|
||||||
|
// The flow of our process is: accept a Conn from remote, read the first packet remote sent us. If it's in the format
|
||||||
|
// of a TLS handshake, we hand it over to the TLS part; if it's in the format of a HTTP request, we process it as a
|
||||||
|
// websocket and eventually wrap the remote Conn as util.WebSocketConn,
|
||||||
|
//
|
||||||
|
// To get a util.WebSocketConn, we need a gorilla/websocket.Conn. This is obtained by using upgrader.Upgrade method
|
||||||
|
// inside a HTTP request handler function (which is defined by us). The HTTP request handler function is invoked by
|
||||||
|
// net/http package upon receiving a request from a Conn.
|
||||||
|
//
|
||||||
|
// Ideally we want to give net/http the connection we got from remote, then it can read the first packet (which should
|
||||||
|
// be an HTTP request) from that Conn and call the handler function, which can then be upgraded to obtain a
|
||||||
|
// gorilla/websocket.Conn. But this won't work for two reasons: one is that we have ALREADY READ the request packet
|
||||||
|
// from the remote Conn to determine if it's TLS or HTTP. When net/http reads from the Conn, it will not receive that
|
||||||
|
// request packet. The second reason is that there is no API in net/http that accepts a Conn at all. Instead, the
|
||||||
|
// closest we can get is http.Serve which takes in a net.Listener and a http.Handler which implements the ServeHTTP
|
||||||
|
// function.
|
||||||
|
//
|
||||||
|
// Recall that net.Listener has a method Accept which blocks until the Listener receives a connection, then
|
||||||
|
// it returns a net.Conn. net/http calls Listener.Accept repeatedly and creates a new goroutine handling each Conn
|
||||||
|
// accepted.
|
||||||
|
//
|
||||||
|
// So here is what we need to do: we need to create a type WsAcceptor that implements net.Listener interface.
|
||||||
|
// the first time WsAcceptor.Accept is called, it will return something that implements net.Conn, subsequent calls to
|
||||||
|
// Accept will return error (so that the caller won't call again)
|
||||||
|
//
|
||||||
|
// The "something that implements net.Conn" needs to do the following: the first time Read is called, it returns the
|
||||||
|
// request packet we got from the remote Conn which we have already read, so that the packet, which is an HTTP request
|
||||||
|
// will be processed by the handling function. Subsequent calls to Read will read directly from the remote Conn. To do
|
||||||
|
// this we create a type firstBuffedConn that implements net.Conn. When we instantiate a firstBuffedConn object, we
|
||||||
|
// give it the request packet we have already read from the remote Conn, as well as the reference to the remote Conn.
|
||||||
|
//
|
||||||
|
// So now we call http.Serve(WsAcceptor, [some handler]), net/http will call WsAcceptor.Accept, which returns a
|
||||||
|
// firstBuffedConn. net/http will call WsAcceptor.Accept again but this time it returns error so net/http will stop.
|
||||||
|
// firstBuffedConn.Read will then be called, which returns the request packet from remote Conn. Then
|
||||||
|
// [some handler].ServeHTTP will be called, in which websocket.upgrader.Upgrade will be called to obtain a
|
||||||
|
// websocket.Conn
|
||||||
|
//
|
||||||
|
// One problem remains: websocket.upgrader.Upgrade is called inside the handling function. The websocket.Conn it
|
||||||
|
// returned needs to be somehow preserved so we can keep using it. To do this, we define a type WsHandshakeHandler
|
||||||
|
// which implements http.Handler. WsHandshakeHandler has a struct field of type net.Conn that can be set. Inside
|
||||||
|
// WsHandshakeHandler.ServeHTTP, the returned websocket.Conn from upgrader.Upgrade will be converted into a
|
||||||
|
// util.WebSocketConn, whose reference will be kept in the struct field. Whoever has the reference to the instance of
|
||||||
|
// WsHandshakeHandler can get the reference to the established util.WebSocketConn.
|
||||||
|
//
|
||||||
|
// There is another problem: the call of http.Serve(WsAcceptor, WsHandshakeHandler) is async. We don't know when
|
||||||
|
// the instance of WsHandshakeHandler will have the util.WebSocketConn ready. We synchronise this using a channel.
|
||||||
|
// A channel called finished will be provided to an instance of WsHandshakeHandler upon its creation. Once
|
||||||
|
// WsHandshakeHandler.ServeHTTP has the reference to util.WebSocketConn ready, it will write to finished.
|
||||||
|
// Outside, immediately after the call to http.Serve(WsAcceptor, WsHandshakeHandler), we read from finished so that the
|
||||||
|
// execution will block until the reference to util.WebSocketConn is ready.
|
||||||
|
|
||||||
|
// since we need to read the first packet from the client to identify its protocol, the first packet will no longer
|
||||||
|
// be in Conn's buffer. However, websocket.Upgrade relies on reading the first packet for handshake, so we must
|
||||||
|
// fake a conn that returns the first packet on first read
|
||||||
|
type firstBuffedConn struct {
|
||||||
|
net.Conn
|
||||||
|
firstRead bool
|
||||||
|
firstPacket []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *firstBuffedConn) Read(buf []byte) (int, error) {
|
||||||
|
if !c.firstRead {
|
||||||
|
c.firstRead = true
|
||||||
|
copy(buf, c.firstPacket)
|
||||||
|
n := len(c.firstPacket)
|
||||||
|
c.firstPacket = []byte{}
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
return c.Conn.Read(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
type wsAcceptor struct {
|
||||||
|
done bool
|
||||||
|
c *firstBuffedConn
|
||||||
|
}
|
||||||
|
|
||||||
|
// net/http provides no method to serve an existing connection, we must feed in a net.Accept interface to get an
|
||||||
|
// http.Server. This is an acceptor that accepts only one Conn
|
||||||
|
func newWsAcceptor(conn net.Conn, first []byte) *wsAcceptor {
|
||||||
|
f := make([]byte, len(first))
|
||||||
|
copy(f, first)
|
||||||
|
return &wsAcceptor{
|
||||||
|
c: &firstBuffedConn{Conn: conn, firstPacket: f},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wsAcceptor) Accept() (net.Conn, error) {
|
||||||
|
if w.done {
|
||||||
|
return nil, errors.New("already accepted")
|
||||||
|
}
|
||||||
|
w.done = true
|
||||||
|
return w.c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wsAcceptor) Close() error {
|
||||||
|
w.done = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wsAcceptor) Addr() net.Addr {
|
||||||
|
return w.c.LocalAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
type wsHandshakeHandler struct {
|
||||||
|
conn net.Conn
|
||||||
|
finished chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// the handler to turn a net.Conn into a websocket.Conn
|
||||||
|
func newWsHandshakeHandler() *wsHandshakeHandler {
|
||||||
|
return &wsHandshakeHandler{finished: make(chan struct{})}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *wsHandshakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
upgrader := websocket.Upgrader{}
|
||||||
|
c, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to upgrade connection to ws: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ws.conn = &util.WebSocketConn{Conn: c}
|
||||||
|
ws.finished <- struct{}{}
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue