Working direct WebSocket transport

This commit is contained in:
Andy Wang 2019-09-01 20:23:45 +01:00
parent f47f57a59f
commit 339b324946
17 changed files with 220 additions and 105 deletions

View File

@ -53,7 +53,8 @@ func makeSession(sta *client.State, isAdmin bool) *mux.Session {
time.Sleep(time.Second * 3)
goto makeconn
}
sk, err := sta.Transport.PrepareConnection(sta, remoteConn)
var sk []byte
remoteConn, sk, err = sta.Transport.PrepareConnection(sta, remoteConn)
if err != nil {
remoteConn.Close()
log.Errorf("Failed to prepare connection to remote: %v", err)
@ -69,7 +70,7 @@ func makeSession(sta *client.State, isAdmin bool) *mux.Session {
log.Debug("All underlying connections established")
sessionKey := _sessionKey.Load().([]byte)
obfuscator, err := mux.GenerateObfs(sta.EncryptionMethod, sessionKey)
obfuscator, err := mux.GenerateObfs(sta.EncryptionMethod, sessionKey, sta.Transport.HasRecordLayer())
if err != nil {
log.Fatal(err)
}
@ -77,7 +78,7 @@ func makeSession(sta *client.State, isAdmin bool) *mux.Session {
seshConfig := &mux.SessionConfig{
Obfuscator: obfuscator,
Valve: nil,
UnitRead: util.ReadTLS,
UnitRead: sta.Transport.UnitReadFunc(),
Unordered: sta.Unordered,
}
sesh := mux.MakeSession(sta.SessionID, seshConfig)

View File

@ -68,7 +68,7 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
sessionKey := make([]byte, 32)
rand.Read(sessionKey)
obfuscator, err := mux.GenerateObfs(ci.EncryptionMethod, sessionKey)
obfuscator, err := mux.GenerateObfs(ci.EncryptionMethod, sessionKey, ci.Transport.HasRecordLayer())
if err != nil {
log.Error(err)
goWeb()
@ -79,7 +79,7 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
// added to the userinfo database. The distinction between going into the admin mode
// and normal proxy mode is that sessionID needs == 0 for admin mode
if bytes.Equal(ci.UID, sta.AdminUID) && ci.SessionId == 0 {
err = finishHandshake(sessionKey)
preparedConn, err := finishHandshake(sessionKey)
if err != nil {
log.Error(err)
return
@ -88,12 +88,12 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
seshConfig := &mux.SessionConfig{
Obfuscator: obfuscator,
Valve: nil,
UnitRead: util.ReadTLS,
UnitRead: ci.Transport.UnitReadFunc(),
}
sesh := mux.MakeSession(0, seshConfig)
sesh.AddConnection(conn)
sesh.AddConnection(preparedConn)
//TODO: Router could be nil in cnc mode
log.WithField("remoteAddr", conn.RemoteAddr()).Info("New admin session")
log.WithField("remoteAddr", preparedConn.RemoteAddr()).Info("New admin session")
err = http.Serve(sesh, sta.LocalAPIRouter)
if err != nil {
log.Error(err)
@ -120,7 +120,7 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
seshConfig := &mux.SessionConfig{
Obfuscator: obfuscator,
Valve: nil,
UnitRead: util.ReadTLS,
UnitRead: ci.Transport.UnitReadFunc(),
Unordered: ci.Unordered,
}
sesh, existing, err := user.GetSession(ci.SessionId, seshConfig)
@ -131,17 +131,17 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
}
if existing {
err = finishHandshake(sesh.SessionKey)
preparedConn, err := finishHandshake(sesh.SessionKey)
if err != nil {
log.Error(err)
return
}
log.Trace("finished handshake")
sesh.AddConnection(conn)
sesh.AddConnection(preparedConn)
return
}
err = finishHandshake(sessionKey)
preparedConn, err := finishHandshake(sessionKey)
if err != nil {
log.Error(err)
return
@ -152,7 +152,7 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
"UID": b64(ci.UID),
"sessionID": ci.SessionId,
}).Info("New session")
sesh.AddConnection(conn)
sesh.AddConnection(preparedConn)
for {
newStream, err := sesh.Accept()

View File

@ -41,13 +41,17 @@ type TLS struct {
Transport
}
func (*TLS) HasRecordLayer() bool { return true }
func (*TLS) UnitReadFunc() func(net.Conn, []byte) (int, error) { return util.ReadTLS }
// PrepareConnection handles the TLS handshake for a given conn and returns the sessionKey
// if the server proceed with Cloak authentication
func (*TLS) PrepareConnection(sta *State, conn net.Conn) (sessionKey []byte, err error) {
func (*TLS) PrepareConnection(sta *State, conn net.Conn) (preparedConn net.Conn, sessionKey []byte, err error) {
preparedConn = conn
hd, sharedSecret := makeHiddenData(sta)
chOnly := sta.browser.composeClientHello(hd)
chWithRecordLayer := util.AddRecordLayer(chOnly, []byte{0x16}, []byte{0x03, 0x01})
_, err = conn.Write(chWithRecordLayer)
_, err = preparedConn.Write(chWithRecordLayer)
if err != nil {
return
}
@ -55,7 +59,7 @@ func (*TLS) PrepareConnection(sta *State, conn net.Conn) (sessionKey []byte, err
buf := make([]byte, 1024)
log.Trace("waiting for ServerHello")
_, err = util.ReadTLS(conn, buf)
_, err = util.ReadTLS(preparedConn, buf)
if err != nil {
return
}
@ -70,12 +74,12 @@ func (*TLS) PrepareConnection(sta *State, conn net.Conn) (sessionKey []byte, err
for i := 0; i < 2; i++ {
// ChangeCipherSpec and EncryptedCert (in the format of application data)
_, err = util.ReadTLS(conn, buf)
_, err = util.ReadTLS(preparedConn, buf)
if err != nil {
return
}
}
return sessionKey, nil
return preparedConn, sessionKey, nil
}

View File

@ -3,5 +3,7 @@ package client
import "net"
type Transport interface {
PrepareConnection(*State, net.Conn) ([]byte, error)
PrepareConnection(*State, net.Conn) (net.Conn, []byte, error)
HasRecordLayer() bool
UnitReadFunc() func(net.Conn, []byte) (int, error)
}

View File

@ -15,30 +15,34 @@ type WebSocket struct {
Transport
}
func (WebSocket) PrepareConnection(sta *State, conn net.Conn) (sessionKey []byte, err error) {
func (*WebSocket) HasRecordLayer() bool { return false }
func (*WebSocket) UnitReadFunc() func(net.Conn, []byte) (int, error) { return util.ReadWebSocket }
func (WebSocket) PrepareConnection(sta *State, conn net.Conn) (preparedConn net.Conn, sessionKey []byte, err error) {
preparedConn = conn
u, err := url.Parse("ws://" + sta.RemoteHost + ":" + sta.RemotePort) //TODO IPv6
if err != nil {
return nil, fmt.Errorf("failed to parse ws url: %v")
return preparedConn, nil, fmt.Errorf("failed to parse ws url: %v", err)
}
hd, sharedSecret := makeHiddenData(sta)
header := http.Header{}
header.Add("hidden", base64.StdEncoding.EncodeToString(hd.fullRaw))
c, _, err := websocket.NewClient(conn, u, header, 16480, 16480)
c, _, err := websocket.NewClient(preparedConn, u, header, 16480, 16480)
if err != nil {
return nil, fmt.Errorf("failed to handshake: %v", err)
return preparedConn, nil, fmt.Errorf("failed to handshake: %v", err)
}
conn = &util.WebSocketConn{c}
preparedConn = &util.WebSocketConn{Conn: c}
buf := make([]byte, 128)
n, err := conn.Read(buf)
n, err := preparedConn.Read(buf)
if err != nil {
return nil, fmt.Errorf("failed to read reply: %v", err)
return preparedConn, nil, fmt.Errorf("failed to read reply: %v", err)
}
if n != 60 {
return nil, errors.New("reply must be 60 bytes")
return preparedConn, nil, errors.New("reply must be 60 bytes")
}
reply := buf[:60]

View File

@ -6,6 +6,7 @@ import (
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/salsa20"
)
@ -26,7 +27,11 @@ const (
E_METHOD_CHACHA20_POLY1305
)
func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Obfser {
func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD, hasRecordLayer bool) Obfser {
var rlLen int
if hasRecordLayer {
rlLen = 5
}
obfs := func(f *Frame, buf []byte) (int, error) {
// we need the encrypted data to be at least 8 bytes to be used as nonce for salsa20 stream header encryption
// this will be the case if the encryption method is an AEAD cipher, however for plain, it's well possible
@ -41,16 +46,15 @@ func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Obfser {
}
// usefulLen is the amount of bytes that will be eventually sent off
usefulLen := 5 + HEADER_LEN + len(f.Payload) + int(extraLen)
usefulLen := rlLen + HEADER_LEN + len(f.Payload) + int(extraLen)
if len(buf) < usefulLen {
return 0, errors.New("buffer is too small")
}
// we do as much in-place as possible to save allocation
useful := buf[:usefulLen] // tls header + payload + potential overhead
recordLayer := useful[0:5]
header := useful[5 : 5+HEADER_LEN]
encryptedPayloadWithExtra := useful[5+HEADER_LEN:]
useful := buf[:usefulLen] // (tls header) + payload + potential overhead
header := useful[rlLen : rlLen+HEADER_LEN]
encryptedPayloadWithExtra := useful[rlLen+HEADER_LEN:]
putU32(header[0:4], f.StreamID)
putU64(header[4:12], f.Seq)
@ -70,25 +74,32 @@ func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Obfser {
nonce := encryptedPayloadWithExtra[len(encryptedPayloadWithExtra)-8:]
salsa20.XORKeyStream(header, header, nonce, &salsaKey)
if hasRecordLayer {
recordLayer := useful[0:5]
// We don't use util.AddRecordLayer here to avoid unnecessary malloc
recordLayer[0] = 0x17
recordLayer[1] = 0x03
recordLayer[2] = 0x03
binary.BigEndian.PutUint16(recordLayer[3:5], uint16(HEADER_LEN+len(encryptedPayloadWithExtra)))
}
// Composing final obfsed message
// We don't use util.AddRecordLayer here to avoid unnecessary malloc
recordLayer[0] = 0x17
recordLayer[1] = 0x03
recordLayer[2] = 0x03
binary.BigEndian.PutUint16(recordLayer[3:5], uint16(HEADER_LEN+len(encryptedPayloadWithExtra)))
return usefulLen, nil
}
return obfs
}
func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser {
func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD, hasRecordLayer bool) Deobfser {
var rlLen int
if hasRecordLayer {
rlLen = 5
}
deobfs := func(in []byte) (*Frame, error) {
if len(in) < 5+HEADER_LEN+8 {
return nil, errors.New("Input cannot be shorter than 27 bytes")
if len(in) < rlLen+HEADER_LEN+8 {
return nil, fmt.Errorf("Input cannot be shorter than %v bytes", rlLen+HEADER_LEN+8)
}
peeled := make([]byte, len(in)-5)
copy(peeled, in[5:])
peeled := make([]byte, len(in)-rlLen)
copy(peeled, in[rlLen:])
header := peeled[:HEADER_LEN]
pldWithOverHead := peeled[HEADER_LEN:] // payload + potential overhead
@ -133,7 +144,7 @@ func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser {
return deobfs
}
func GenerateObfs(encryptionMethod byte, sessionKey []byte) (obfuscator *Obfuscator, err error) {
func GenerateObfs(encryptionMethod byte, sessionKey []byte, hasRecordLayer bool) (obfuscator *Obfuscator, err error) {
if len(sessionKey) != 32 {
err = errors.New("sessionKey size must be 32 bytes")
}
@ -165,8 +176,8 @@ func GenerateObfs(encryptionMethod byte, sessionKey []byte) (obfuscator *Obfusca
}
obfuscator = &Obfuscator{
MakeObfs(salsaKey, payloadCipher),
MakeDeobfs(salsaKey, payloadCipher),
MakeObfs(salsaKey, payloadCipher, hasRecordLayer),
MakeDeobfs(salsaKey, payloadCipher, hasRecordLayer),
sessionKey,
}
return

View File

@ -39,7 +39,15 @@ func TestGenerateObfs(t *testing.T) {
}
t.Run("plain", func(t *testing.T) {
obfuscator, err := GenerateObfs(E_METHOD_PLAIN, sessionKey)
obfuscator, err := GenerateObfs(E_METHOD_PLAIN, sessionKey, true)
if err != nil {
t.Errorf("failed to generate obfuscator %v", err)
} else {
run(obfuscator, t)
}
})
t.Run("plain no record layer", func(t *testing.T) {
obfuscator, err := GenerateObfs(E_METHOD_PLAIN, sessionKey, false)
if err != nil {
t.Errorf("failed to generate obfuscator %v", err)
} else {
@ -47,7 +55,15 @@ func TestGenerateObfs(t *testing.T) {
}
})
t.Run("aes-gcm", func(t *testing.T) {
obfuscator, err := GenerateObfs(E_METHOD_AES_GCM, sessionKey)
obfuscator, err := GenerateObfs(E_METHOD_AES_GCM, sessionKey, true)
if err != nil {
t.Errorf("failed to generate obfuscator %v", err)
} else {
run(obfuscator, t)
}
})
t.Run("aes-gcm no record layer", func(t *testing.T) {
obfuscator, err := GenerateObfs(E_METHOD_AES_GCM, sessionKey, false)
if err != nil {
t.Errorf("failed to generate obfuscator %v", err)
} else {
@ -55,7 +71,7 @@ func TestGenerateObfs(t *testing.T) {
}
})
t.Run("chacha20-poly1305", func(t *testing.T) {
obfuscator, err := GenerateObfs(E_METHOD_CHACHA20_POLY1305, sessionKey)
obfuscator, err := GenerateObfs(E_METHOD_CHACHA20_POLY1305, sessionKey, true)
if err != nil {
t.Errorf("failed to generate obfuscator %v", err)
} else {
@ -63,13 +79,13 @@ func TestGenerateObfs(t *testing.T) {
}
})
t.Run("unknown encryption method", func(t *testing.T) {
_, err := GenerateObfs(0xff, sessionKey)
_, err := GenerateObfs(0xff, sessionKey, true)
if err == nil {
t.Errorf("unknown encryption mehtod error expected")
}
})
t.Run("bad key length", func(t *testing.T) {
_, err := GenerateObfs(0xff, sessionKey[:31])
_, err := GenerateObfs(0xff, sessionKey[:31], true)
if err == nil {
t.Errorf("bad key length error expected")
}
@ -94,7 +110,7 @@ func BenchmarkObfs(b *testing.B) {
c, _ := aes.NewCipher(key[:])
payloadCipher, _ := cipher.NewGCM(c)
obfs := MakeObfs(key, payloadCipher)
obfs := MakeObfs(key, payloadCipher, true)
b.ResetTimer()
for i := 0; i < b.N; i++ {
n, err := obfs(testFrame, obfsBuf)
@ -109,7 +125,7 @@ func BenchmarkObfs(b *testing.B) {
c, _ := aes.NewCipher(key[:16])
payloadCipher, _ := cipher.NewGCM(c)
obfs := MakeObfs(key, payloadCipher)
obfs := MakeObfs(key, payloadCipher, true)
b.ResetTimer()
for i := 0; i < b.N; i++ {
n, err := obfs(testFrame, obfsBuf)
@ -121,7 +137,7 @@ func BenchmarkObfs(b *testing.B) {
}
})
b.Run("plain", func(b *testing.B) {
obfs := MakeObfs(key, nil)
obfs := MakeObfs(key, nil, true)
b.ResetTimer()
for i := 0; i < b.N; i++ {
n, err := obfs(testFrame, obfsBuf)
@ -135,7 +151,7 @@ func BenchmarkObfs(b *testing.B) {
b.Run("chacha20Poly1305", func(b *testing.B) {
payloadCipher, _ := chacha20poly1305.New(key[:16])
obfs := MakeObfs(key, payloadCipher)
obfs := MakeObfs(key, payloadCipher, true)
b.ResetTimer()
for i := 0; i < b.N; i++ {
n, err := obfs(testFrame, obfsBuf)
@ -166,9 +182,9 @@ func BenchmarkDeobfs(b *testing.B) {
c, _ := aes.NewCipher(key[:])
payloadCipher, _ := cipher.NewGCM(c)
obfs := MakeObfs(key, payloadCipher)
obfs := MakeObfs(key, payloadCipher, true)
n, _ := obfs(testFrame, obfsBuf)
deobfs := MakeDeobfs(key, payloadCipher)
deobfs := MakeDeobfs(key, payloadCipher, true)
b.ResetTimer()
for i := 0; i < b.N; i++ {
@ -184,9 +200,9 @@ func BenchmarkDeobfs(b *testing.B) {
c, _ := aes.NewCipher(key[:16])
payloadCipher, _ := cipher.NewGCM(c)
obfs := MakeObfs(key, payloadCipher)
obfs := MakeObfs(key, payloadCipher, true)
n, _ := obfs(testFrame, obfsBuf)
deobfs := MakeDeobfs(key, payloadCipher)
deobfs := MakeDeobfs(key, payloadCipher, true)
b.ResetTimer()
for i := 0; i < b.N; i++ {
@ -199,9 +215,9 @@ func BenchmarkDeobfs(b *testing.B) {
}
})
b.Run("plain", func(b *testing.B) {
obfs := MakeObfs(key, nil)
obfs := MakeObfs(key, nil, true)
n, _ := obfs(testFrame, obfsBuf)
deobfs := MakeDeobfs(key, nil)
deobfs := MakeDeobfs(key, nil, true)
b.ResetTimer()
for i := 0; i < b.N; i++ {
@ -216,9 +232,9 @@ func BenchmarkDeobfs(b *testing.B) {
b.Run("chacha20Poly1305", func(b *testing.B) {
payloadCipher, _ := chacha20poly1305.New(key[:16])
obfs := MakeObfs(key, payloadCipher)
obfs := MakeObfs(key, payloadCipher, true)
n, _ := obfs(testFrame, obfsBuf)
deobfs := MakeDeobfs(key, payloadCipher)
deobfs := MakeDeobfs(key, payloadCipher, true)
b.ResetTimer()
for i := 0; i < b.N; i++ {

View File

@ -30,6 +30,8 @@ type Obfuscator struct {
type switchboardStrategy int
type SessionConfig struct {
NoRecordLayer bool
*Obfuscator
Valve

View File

@ -35,7 +35,7 @@ func TestRecvDataFromRemote(t *testing.T) {
sessionKey := make([]byte, 32)
rand.Read(sessionKey)
t.Run("plain ordered", func(t *testing.T) {
obfuscator, _ := GenerateObfs(E_METHOD_PLAIN, sessionKey)
obfuscator, _ := GenerateObfs(E_METHOD_PLAIN, sessionKey, true)
seshConfigOrdered.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfigOrdered)
n, _ := sesh.Obfs(f, obfsBuf)
@ -58,7 +58,7 @@ func TestRecvDataFromRemote(t *testing.T) {
}
})
t.Run("aes-gcm ordered", func(t *testing.T) {
obfuscator, _ := GenerateObfs(E_METHOD_AES_GCM, sessionKey)
obfuscator, _ := GenerateObfs(E_METHOD_AES_GCM, sessionKey, true)
seshConfigOrdered.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfigOrdered)
n, _ := sesh.Obfs(f, obfsBuf)
@ -81,7 +81,7 @@ func TestRecvDataFromRemote(t *testing.T) {
}
})
t.Run("chacha20-poly1305 ordered", func(t *testing.T) {
obfuscator, _ := GenerateObfs(E_METHOD_CHACHA20_POLY1305, sessionKey)
obfuscator, _ := GenerateObfs(E_METHOD_CHACHA20_POLY1305, sessionKey, true)
seshConfigOrdered.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfigOrdered)
n, _ := sesh.Obfs(f, obfsBuf)
@ -105,7 +105,7 @@ func TestRecvDataFromRemote(t *testing.T) {
})
t.Run("plain unordered", func(t *testing.T) {
obfuscator, _ := GenerateObfs(E_METHOD_PLAIN, sessionKey)
obfuscator, _ := GenerateObfs(E_METHOD_PLAIN, sessionKey, true)
seshConfigUnordered.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfigOrdered)
n, _ := sesh.Obfs(f, obfsBuf)
@ -146,7 +146,7 @@ func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) {
rand.Read(sessionKey)
b.Run("plain", func(b *testing.B) {
obfuscator, _ := GenerateObfs(E_METHOD_PLAIN, sessionKey)
obfuscator, _ := GenerateObfs(E_METHOD_PLAIN, sessionKey, true)
seshConfigOrdered.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfigOrdered)
n, _ := sesh.Obfs(f, obfsBuf)
@ -159,7 +159,7 @@ func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) {
})
b.Run("aes-gcm", func(b *testing.B) {
obfuscator, _ := GenerateObfs(E_METHOD_AES_GCM, sessionKey)
obfuscator, _ := GenerateObfs(E_METHOD_AES_GCM, sessionKey, true)
seshConfigOrdered.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfigOrdered)
n, _ := sesh.Obfs(f, obfsBuf)
@ -172,7 +172,7 @@ func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) {
})
b.Run("chacha20-poly1305", func(b *testing.B) {
obfuscator, _ := GenerateObfs(E_METHOD_CHACHA20_POLY1305, sessionKey)
obfuscator, _ := GenerateObfs(E_METHOD_CHACHA20_POLY1305, sessionKey, true)
seshConfigOrdered.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfigOrdered)
n, _ := sesh.Obfs(f, obfsBuf)

View File

@ -51,7 +51,7 @@ func makeStream(sesh *Session, id uint32, assignedConnId uint32) *Stream {
id: id,
session: sesh,
recvBuf: recvBuf,
obfsBuf: make([]byte, 17000),
obfsBuf: make([]byte, 17000), //TODO don't leave this hardcoded
assignedConnId: assignedConnId,
}

View File

@ -14,7 +14,7 @@ import (
func setupSesh(unordered bool) *Session {
sessionKey := make([]byte, 32)
rand.Read(sessionKey)
obfuscator, _ := GenerateObfs(0x00, sessionKey)
obfuscator, _ := GenerateObfs(0x00, sessionKey, true)
seshConfig := &SessionConfig{
Obfuscator: obfuscator,
@ -144,6 +144,7 @@ func TestStream_Read(t *testing.T) {
_, err := conn.Write(data)
if err != nil {
t.Error("cannot write to connection", err)
return
}
}
}()
@ -163,17 +164,21 @@ func TestStream_Read(t *testing.T) {
stream, err := sesh.Accept()
if err != nil {
t.Error("failed to accept stream", err)
return
}
i, err = stream.Read(buf)
if err != nil {
t.Error("failed to read", err)
return
}
if i != PAYLOAD_LEN {
t.Errorf("expected read %v, got %v", PAYLOAD_LEN, i)
return
}
if !bytes.Equal(buf[:i], testPayload) {
t.Error("expected", testPayload,
"got", buf[:i])
return
}
})
t.Run("Nil buf", func(t *testing.T) {

View File

@ -216,11 +216,6 @@ func composeReply(ch *ClientHello, sharedSecret []byte, sessionKey []byte) ([]by
return ret, nil
}
var ErrBadClientHello = errors.New("non (or malformed) ClientHello")
var ErrNotCloak = errors.New("TLS but non-Cloak ClientHello")
var ErrReplay = errors.New("duplicate random")
var ErrBadProxyMethod = errors.New("invalid proxy method")
func unmarshalClientHello(ch *ClientHello, staticPv crypto.PrivateKey) (ai authenticationInfo, err error) {
ephPub, ok := ecdh.Unmarshal(ch.random)
if !ok {

View File

@ -22,6 +22,7 @@ type ClientInfo struct {
ProxyMethod string
EncryptionMethod byte
Unordered bool
Transport Transport
}
type authenticationInfo struct {
@ -67,14 +68,20 @@ func touchStone(ai authenticationInfo, now func() time.Time) (info ClientInfo, e
return
}
var ErrBadClientHello = errors.New("non (or malformed) ClientHello")
var ErrReplay = errors.New("duplicate random")
var ErrBadProxyMethod = errors.New("invalid proxy method")
// PrepareConnection checks if the first packet of data is ClientHello or HTTP GET, and checks if it was from a Cloak client
// if it is from a Cloak client, it returns the ClientInfo with the decrypted fields. It doesn't check if the user
// is authorised. It also returns a finisher callback function to be called when the caller wishes to proceed with
// the handshake
func PrepareConnection(firstPacket []byte, sta *State, conn net.Conn) (info ClientInfo, finisher func([]byte) error, err error) {
func PrepareConnection(firstPacket []byte, sta *State, conn net.Conn) (info ClientInfo, finisher func([]byte) (net.Conn, error), err error) {
var transport Transport
var ai authenticationInfo
switch firstPacket[0] {
case 0x47:
transport = &WebSocket{}
var req *http.Request
req, err = http.ReadRequest(bufio.NewReader(bytes.NewBuffer(firstPacket)))
if err != nil {
@ -90,30 +97,33 @@ func PrepareConnection(firstPacket []byte, sta *State, conn net.Conn) (info Clie
return
}
finisher = func(sessionKey []byte) error {
finisher = func(sessionKey []byte) (preparedConn net.Conn, err error) {
handler := newWsHandshakeHandler()
go http.Serve(newWsAcceptor(conn, firstPacket), handler)
http.Serve(newWsAcceptor(conn, firstPacket), handler)
<-handler.finished
conn = handler.conn
preparedConn = handler.conn
nonce := make([]byte, 12)
rand.Read(nonce)
// reply: [12 bytes nonce][32 bytes encrypted session key][16 bytes authentication tag]
encryptedKey, err := util.AESGCMEncrypt(nonce, ai.sharedSecret, sessionKey) // 32 + 16 = 48 bytes
if err != nil {
return fmt.Errorf("failed to encrypt reply: %v", err)
err = fmt.Errorf("failed to encrypt reply: %v", err)
return
}
reply := append(nonce, encryptedKey...)
_, err = conn.Write(reply)
_, err = preparedConn.Write(reply)
if err != nil {
go conn.Close()
return fmt.Errorf("failed to write reply: %v", err)
err = fmt.Errorf("failed to write reply: %v", err)
go preparedConn.Close()
return
}
return nil
return
}
case 0x16:
transport = &TLS{}
var ch *ClientHello
ch, err = parseClientHello(firstPacket)
if err != nil {
@ -132,17 +142,21 @@ func PrepareConnection(firstPacket []byte, sta *State, conn net.Conn) (info Clie
err = fmt.Errorf("failed to unmarshal ClientHello into authenticationInfo: %v", err)
return
}
finisher = func(sessionKey []byte) error {
finisher = func(sessionKey []byte) (preparedConn net.Conn, err error) {
preparedConn = conn
reply, err := composeReply(ch, ai.sharedSecret, sessionKey)
if err != nil {
return fmt.Errorf("failed to compose TLS reply: %v", err)
err = fmt.Errorf("failed to compose TLS reply: %v", err)
return
}
_, err = conn.Write(reply)
_, err = preparedConn.Write(reply)
if err != nil {
go conn.Close()
return fmt.Errorf("failed to write TLS reply: %v", err)
err = fmt.Errorf("failed to write TLS reply: %v", err)
go preparedConn.Close()
return
}
return nil
return
}
default:
err = ErrUnreconisedProtocol
@ -152,9 +166,10 @@ func PrepareConnection(firstPacket []byte, sta *State, conn net.Conn) (info Clie
info, err = touchStone(ai, sta.Now)
if err != nil {
log.Debug(err)
err = ErrNotCloak
err = fmt.Errorf("transport %v in correct format but not Cloak: %v", err)
return
}
info.Transport = transport
if _, ok := sta.ProxyBook[info.ProxyMethod]; !ok {
err = ErrBadProxyMethod
return

View File

@ -0,0 +1,21 @@
package server
import (
"github.com/cbeuw/Cloak/internal/util"
"net"
)
type Transport interface {
HasRecordLayer() bool
UnitReadFunc() func(net.Conn, []byte) (int, error)
}
type TLS struct{}
func (*TLS) HasRecordLayer() bool { return true }
func (*TLS) UnitReadFunc() func(net.Conn, []byte) (int, error) { return util.ReadTLS }
type WebSocket struct{}
func (*WebSocket) HasRecordLayer() bool { return false }
func (*WebSocket) UnitReadFunc() func(net.Conn, []byte) (int, error) { return util.ReadWebSocket }

View File

@ -21,12 +21,13 @@ type firstBuffedConn struct {
func (c *firstBuffedConn) Read(buf []byte) (int, error) {
if !c.firstRead {
c.firstRead = true
copy(buf, c.firstPacket)
n := len(c.firstPacket)
c.firstPacket = []byte{}
return n, nil
}
return c.Read(buf)
return c.Conn.Read(buf)
}
type wsAcceptor struct {
@ -38,7 +39,7 @@ func newWsAcceptor(conn net.Conn, first []byte) *wsAcceptor {
f := make([]byte, len(first))
copy(f, first)
return &wsAcceptor{
c: &firstBuffedConn{Conn: conn, firstPacket: first},
c: &firstBuffedConn{Conn: conn, firstPacket: f},
}
}
@ -69,16 +70,13 @@ func newWsHandshakeHandler() *wsHandshakeHandler {
}
func (ws *wsHandshakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
upgrader := websocket.Upgrader{
ReadBufferSize: 16380,
WriteBufferSize: 16380,
}
upgrader := websocket.Upgrader{}
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Errorf("failed to upgrade connection to ws: %v", err)
return
}
ws.conn = &util.WebSocketConn{c}
ws.conn = &util.WebSocketConn{Conn: c}
ws.finished <- struct{}{}
}

View File

@ -89,10 +89,10 @@ func AddRecordLayer(input []byte, typ []byte, ver []byte) []byte {
}
func Pipe(dst net.Conn, src net.Conn, srcReadTimeout time.Duration) {
// The maximum size of TLS message will be 16380+12+16. 12 because of the stream header and 16
// The maximum size of TLS message will be 16380+14+16. 14 because of the stream header and 16
// because of the salt/mac
// 16408 is the max TLS message size on Firefox
buf := make([]byte, 16380)
buf := make([]byte, 16378)
if srcReadTimeout != 0 {
src.SetReadDeadline(time.Now().Add(srcReadTimeout))
}

View File

@ -1,16 +1,23 @@
package util
import (
"errors"
"github.com/gorilla/websocket"
"io"
"net"
"sync"
"time"
)
type WebSocketConn struct {
*websocket.Conn
writeM sync.Mutex
}
func (ws *WebSocketConn) Write(data []byte) (int, error) {
ws.writeM.Lock()
err := ws.WriteMessage(websocket.BinaryMessage, data)
ws.writeM.Unlock()
if err != nil {
return 0, err
} else {
@ -18,12 +25,41 @@ func (ws *WebSocketConn) Write(data []byte) (int, error) {
}
}
func (ws *WebSocketConn) Read(buf []byte) (int, error) {
_, r, err := ws.NextReader()
func (ws *WebSocketConn) Read(buf []byte) (n int, err error) {
t, r, err := ws.NextReader()
if err != nil {
return 0, err
}
return r.Read(buf)
if t != websocket.BinaryMessage {
return 0, nil
}
// Read until io.EOL for one full message
for {
var read int
read, err = r.Read(buf[n:])
if err != nil {
if err == io.EOF {
err = nil
break
} else {
break
}
} else {
// There may be data available to read but n == len(buf)-1, read==0 because buffer is full
if read == 0 {
err = errors.New("nothing more is read. message may be larger than buffer")
break
}
}
n += read
}
return
}
func (ws *WebSocketConn) Close() error {
ws.writeM.Lock()
defer ws.writeM.Unlock()
return ws.Conn.Close()
}
func (ws *WebSocketConn) SetDeadline(t time.Time) error {
@ -37,3 +73,8 @@ func (ws *WebSocketConn) SetDeadline(t time.Time) error {
}
return nil
}
// ws unit reader
func ReadWebSocket(conn net.Conn, buffer []byte) (n int, err error) {
return conn.Read(buffer)
}