From 339b324946c8b6eda2d5b189b1c5426c32c59ad0 Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Sun, 1 Sep 2019 20:23:45 +0100 Subject: [PATCH] Working direct WebSocket transport --- cmd/ck-client/ck-client.go | 7 +++-- cmd/ck-server/ck-server.go | 20 ++++++------ internal/client/TLS.go | 14 ++++++--- internal/client/transport.go | 4 ++- internal/client/websocket.go | 20 +++++++----- internal/multiplex/obfs.go | 49 +++++++++++++++++------------ internal/multiplex/obfs_test.go | 50 ++++++++++++++++++++---------- internal/multiplex/session.go | 2 ++ internal/multiplex/session_test.go | 14 ++++----- internal/multiplex/stream.go | 2 +- internal/multiplex/stream_test.go | 7 ++++- internal/server/TLS.go | 5 --- internal/server/auth.go | 47 ++++++++++++++++++---------- internal/server/transport.go | 21 +++++++++++++ internal/server/websocket.go | 12 +++---- internal/util/util.go | 4 +-- internal/util/websocket.go | 47 ++++++++++++++++++++++++++-- 17 files changed, 220 insertions(+), 105 deletions(-) create mode 100644 internal/server/transport.go diff --git a/cmd/ck-client/ck-client.go b/cmd/ck-client/ck-client.go index bf9f9a4..521d551 100644 --- a/cmd/ck-client/ck-client.go +++ b/cmd/ck-client/ck-client.go @@ -53,7 +53,8 @@ func makeSession(sta *client.State, isAdmin bool) *mux.Session { time.Sleep(time.Second * 3) goto makeconn } - sk, err := sta.Transport.PrepareConnection(sta, remoteConn) + var sk []byte + remoteConn, sk, err = sta.Transport.PrepareConnection(sta, remoteConn) if err != nil { remoteConn.Close() log.Errorf("Failed to prepare connection to remote: %v", err) @@ -69,7 +70,7 @@ func makeSession(sta *client.State, isAdmin bool) *mux.Session { log.Debug("All underlying connections established") sessionKey := _sessionKey.Load().([]byte) - obfuscator, err := mux.GenerateObfs(sta.EncryptionMethod, sessionKey) + obfuscator, err := mux.GenerateObfs(sta.EncryptionMethod, sessionKey, sta.Transport.HasRecordLayer()) if err != nil { log.Fatal(err) } @@ -77,7 +78,7 @@ func makeSession(sta *client.State, isAdmin bool) *mux.Session { seshConfig := &mux.SessionConfig{ Obfuscator: obfuscator, Valve: nil, - UnitRead: util.ReadTLS, + UnitRead: sta.Transport.UnitReadFunc(), Unordered: sta.Unordered, } sesh := mux.MakeSession(sta.SessionID, seshConfig) diff --git a/cmd/ck-server/ck-server.go b/cmd/ck-server/ck-server.go index 52c568d..6ecdd0a 100644 --- a/cmd/ck-server/ck-server.go +++ b/cmd/ck-server/ck-server.go @@ -68,7 +68,7 @@ func dispatchConnection(conn net.Conn, sta *server.State) { sessionKey := make([]byte, 32) rand.Read(sessionKey) - obfuscator, err := mux.GenerateObfs(ci.EncryptionMethod, sessionKey) + obfuscator, err := mux.GenerateObfs(ci.EncryptionMethod, sessionKey, ci.Transport.HasRecordLayer()) if err != nil { log.Error(err) goWeb() @@ -79,7 +79,7 @@ func dispatchConnection(conn net.Conn, sta *server.State) { // added to the userinfo database. The distinction between going into the admin mode // and normal proxy mode is that sessionID needs == 0 for admin mode if bytes.Equal(ci.UID, sta.AdminUID) && ci.SessionId == 0 { - err = finishHandshake(sessionKey) + preparedConn, err := finishHandshake(sessionKey) if err != nil { log.Error(err) return @@ -88,12 +88,12 @@ func dispatchConnection(conn net.Conn, sta *server.State) { seshConfig := &mux.SessionConfig{ Obfuscator: obfuscator, Valve: nil, - UnitRead: util.ReadTLS, + UnitRead: ci.Transport.UnitReadFunc(), } sesh := mux.MakeSession(0, seshConfig) - sesh.AddConnection(conn) + sesh.AddConnection(preparedConn) //TODO: Router could be nil in cnc mode - log.WithField("remoteAddr", conn.RemoteAddr()).Info("New admin session") + log.WithField("remoteAddr", preparedConn.RemoteAddr()).Info("New admin session") err = http.Serve(sesh, sta.LocalAPIRouter) if err != nil { log.Error(err) @@ -120,7 +120,7 @@ func dispatchConnection(conn net.Conn, sta *server.State) { seshConfig := &mux.SessionConfig{ Obfuscator: obfuscator, Valve: nil, - UnitRead: util.ReadTLS, + UnitRead: ci.Transport.UnitReadFunc(), Unordered: ci.Unordered, } sesh, existing, err := user.GetSession(ci.SessionId, seshConfig) @@ -131,17 +131,17 @@ func dispatchConnection(conn net.Conn, sta *server.State) { } if existing { - err = finishHandshake(sesh.SessionKey) + preparedConn, err := finishHandshake(sesh.SessionKey) if err != nil { log.Error(err) return } log.Trace("finished handshake") - sesh.AddConnection(conn) + sesh.AddConnection(preparedConn) return } - err = finishHandshake(sessionKey) + preparedConn, err := finishHandshake(sessionKey) if err != nil { log.Error(err) return @@ -152,7 +152,7 @@ func dispatchConnection(conn net.Conn, sta *server.State) { "UID": b64(ci.UID), "sessionID": ci.SessionId, }).Info("New session") - sesh.AddConnection(conn) + sesh.AddConnection(preparedConn) for { newStream, err := sesh.Accept() diff --git a/internal/client/TLS.go b/internal/client/TLS.go index e7d8470..162c50f 100644 --- a/internal/client/TLS.go +++ b/internal/client/TLS.go @@ -41,13 +41,17 @@ type TLS struct { Transport } +func (*TLS) HasRecordLayer() bool { return true } +func (*TLS) UnitReadFunc() func(net.Conn, []byte) (int, error) { return util.ReadTLS } + // PrepareConnection handles the TLS handshake for a given conn and returns the sessionKey // if the server proceed with Cloak authentication -func (*TLS) PrepareConnection(sta *State, conn net.Conn) (sessionKey []byte, err error) { +func (*TLS) PrepareConnection(sta *State, conn net.Conn) (preparedConn net.Conn, sessionKey []byte, err error) { + preparedConn = conn hd, sharedSecret := makeHiddenData(sta) chOnly := sta.browser.composeClientHello(hd) chWithRecordLayer := util.AddRecordLayer(chOnly, []byte{0x16}, []byte{0x03, 0x01}) - _, err = conn.Write(chWithRecordLayer) + _, err = preparedConn.Write(chWithRecordLayer) if err != nil { return } @@ -55,7 +59,7 @@ func (*TLS) PrepareConnection(sta *State, conn net.Conn) (sessionKey []byte, err buf := make([]byte, 1024) log.Trace("waiting for ServerHello") - _, err = util.ReadTLS(conn, buf) + _, err = util.ReadTLS(preparedConn, buf) if err != nil { return } @@ -70,12 +74,12 @@ func (*TLS) PrepareConnection(sta *State, conn net.Conn) (sessionKey []byte, err for i := 0; i < 2; i++ { // ChangeCipherSpec and EncryptedCert (in the format of application data) - _, err = util.ReadTLS(conn, buf) + _, err = util.ReadTLS(preparedConn, buf) if err != nil { return } } - return sessionKey, nil + return preparedConn, sessionKey, nil } diff --git a/internal/client/transport.go b/internal/client/transport.go index 7cd7a66..411c3de 100644 --- a/internal/client/transport.go +++ b/internal/client/transport.go @@ -3,5 +3,7 @@ package client import "net" type Transport interface { - PrepareConnection(*State, net.Conn) ([]byte, error) + PrepareConnection(*State, net.Conn) (net.Conn, []byte, error) + HasRecordLayer() bool + UnitReadFunc() func(net.Conn, []byte) (int, error) } diff --git a/internal/client/websocket.go b/internal/client/websocket.go index d9a3c04..bc1315a 100644 --- a/internal/client/websocket.go +++ b/internal/client/websocket.go @@ -15,30 +15,34 @@ type WebSocket struct { Transport } -func (WebSocket) PrepareConnection(sta *State, conn net.Conn) (sessionKey []byte, err error) { +func (*WebSocket) HasRecordLayer() bool { return false } +func (*WebSocket) UnitReadFunc() func(net.Conn, []byte) (int, error) { return util.ReadWebSocket } + +func (WebSocket) PrepareConnection(sta *State, conn net.Conn) (preparedConn net.Conn, sessionKey []byte, err error) { + preparedConn = conn u, err := url.Parse("ws://" + sta.RemoteHost + ":" + sta.RemotePort) //TODO IPv6 if err != nil { - return nil, fmt.Errorf("failed to parse ws url: %v") + return preparedConn, nil, fmt.Errorf("failed to parse ws url: %v", err) } hd, sharedSecret := makeHiddenData(sta) header := http.Header{} header.Add("hidden", base64.StdEncoding.EncodeToString(hd.fullRaw)) - c, _, err := websocket.NewClient(conn, u, header, 16480, 16480) + c, _, err := websocket.NewClient(preparedConn, u, header, 16480, 16480) if err != nil { - return nil, fmt.Errorf("failed to handshake: %v", err) + return preparedConn, nil, fmt.Errorf("failed to handshake: %v", err) } - conn = &util.WebSocketConn{c} + preparedConn = &util.WebSocketConn{Conn: c} buf := make([]byte, 128) - n, err := conn.Read(buf) + n, err := preparedConn.Read(buf) if err != nil { - return nil, fmt.Errorf("failed to read reply: %v", err) + return preparedConn, nil, fmt.Errorf("failed to read reply: %v", err) } if n != 60 { - return nil, errors.New("reply must be 60 bytes") + return preparedConn, nil, errors.New("reply must be 60 bytes") } reply := buf[:60] diff --git a/internal/multiplex/obfs.go b/internal/multiplex/obfs.go index 1a27829..c8479a1 100644 --- a/internal/multiplex/obfs.go +++ b/internal/multiplex/obfs.go @@ -6,6 +6,7 @@ import ( "crypto/rand" "encoding/binary" "errors" + "fmt" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/salsa20" ) @@ -26,7 +27,11 @@ const ( E_METHOD_CHACHA20_POLY1305 ) -func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Obfser { +func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD, hasRecordLayer bool) Obfser { + var rlLen int + if hasRecordLayer { + rlLen = 5 + } obfs := func(f *Frame, buf []byte) (int, error) { // we need the encrypted data to be at least 8 bytes to be used as nonce for salsa20 stream header encryption // this will be the case if the encryption method is an AEAD cipher, however for plain, it's well possible @@ -41,16 +46,15 @@ func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Obfser { } // usefulLen is the amount of bytes that will be eventually sent off - usefulLen := 5 + HEADER_LEN + len(f.Payload) + int(extraLen) + usefulLen := rlLen + HEADER_LEN + len(f.Payload) + int(extraLen) if len(buf) < usefulLen { return 0, errors.New("buffer is too small") } // we do as much in-place as possible to save allocation - useful := buf[:usefulLen] // tls header + payload + potential overhead - recordLayer := useful[0:5] - header := useful[5 : 5+HEADER_LEN] - encryptedPayloadWithExtra := useful[5+HEADER_LEN:] + useful := buf[:usefulLen] // (tls header) + payload + potential overhead + header := useful[rlLen : rlLen+HEADER_LEN] + encryptedPayloadWithExtra := useful[rlLen+HEADER_LEN:] putU32(header[0:4], f.StreamID) putU64(header[4:12], f.Seq) @@ -70,25 +74,32 @@ func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Obfser { nonce := encryptedPayloadWithExtra[len(encryptedPayloadWithExtra)-8:] salsa20.XORKeyStream(header, header, nonce, &salsaKey) + if hasRecordLayer { + recordLayer := useful[0:5] + // We don't use util.AddRecordLayer here to avoid unnecessary malloc + recordLayer[0] = 0x17 + recordLayer[1] = 0x03 + recordLayer[2] = 0x03 + binary.BigEndian.PutUint16(recordLayer[3:5], uint16(HEADER_LEN+len(encryptedPayloadWithExtra))) + } // Composing final obfsed message - // We don't use util.AddRecordLayer here to avoid unnecessary malloc - recordLayer[0] = 0x17 - recordLayer[1] = 0x03 - recordLayer[2] = 0x03 - binary.BigEndian.PutUint16(recordLayer[3:5], uint16(HEADER_LEN+len(encryptedPayloadWithExtra))) return usefulLen, nil } return obfs } -func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser { +func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD, hasRecordLayer bool) Deobfser { + var rlLen int + if hasRecordLayer { + rlLen = 5 + } deobfs := func(in []byte) (*Frame, error) { - if len(in) < 5+HEADER_LEN+8 { - return nil, errors.New("Input cannot be shorter than 27 bytes") + if len(in) < rlLen+HEADER_LEN+8 { + return nil, fmt.Errorf("Input cannot be shorter than %v bytes", rlLen+HEADER_LEN+8) } - peeled := make([]byte, len(in)-5) - copy(peeled, in[5:]) + peeled := make([]byte, len(in)-rlLen) + copy(peeled, in[rlLen:]) header := peeled[:HEADER_LEN] pldWithOverHead := peeled[HEADER_LEN:] // payload + potential overhead @@ -133,7 +144,7 @@ func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser { return deobfs } -func GenerateObfs(encryptionMethod byte, sessionKey []byte) (obfuscator *Obfuscator, err error) { +func GenerateObfs(encryptionMethod byte, sessionKey []byte, hasRecordLayer bool) (obfuscator *Obfuscator, err error) { if len(sessionKey) != 32 { err = errors.New("sessionKey size must be 32 bytes") } @@ -165,8 +176,8 @@ func GenerateObfs(encryptionMethod byte, sessionKey []byte) (obfuscator *Obfusca } obfuscator = &Obfuscator{ - MakeObfs(salsaKey, payloadCipher), - MakeDeobfs(salsaKey, payloadCipher), + MakeObfs(salsaKey, payloadCipher, hasRecordLayer), + MakeDeobfs(salsaKey, payloadCipher, hasRecordLayer), sessionKey, } return diff --git a/internal/multiplex/obfs_test.go b/internal/multiplex/obfs_test.go index 1dddd4d..6a1e77d 100644 --- a/internal/multiplex/obfs_test.go +++ b/internal/multiplex/obfs_test.go @@ -39,7 +39,15 @@ func TestGenerateObfs(t *testing.T) { } t.Run("plain", func(t *testing.T) { - obfuscator, err := GenerateObfs(E_METHOD_PLAIN, sessionKey) + obfuscator, err := GenerateObfs(E_METHOD_PLAIN, sessionKey, true) + if err != nil { + t.Errorf("failed to generate obfuscator %v", err) + } else { + run(obfuscator, t) + } + }) + t.Run("plain no record layer", func(t *testing.T) { + obfuscator, err := GenerateObfs(E_METHOD_PLAIN, sessionKey, false) if err != nil { t.Errorf("failed to generate obfuscator %v", err) } else { @@ -47,7 +55,15 @@ func TestGenerateObfs(t *testing.T) { } }) t.Run("aes-gcm", func(t *testing.T) { - obfuscator, err := GenerateObfs(E_METHOD_AES_GCM, sessionKey) + obfuscator, err := GenerateObfs(E_METHOD_AES_GCM, sessionKey, true) + if err != nil { + t.Errorf("failed to generate obfuscator %v", err) + } else { + run(obfuscator, t) + } + }) + t.Run("aes-gcm no record layer", func(t *testing.T) { + obfuscator, err := GenerateObfs(E_METHOD_AES_GCM, sessionKey, false) if err != nil { t.Errorf("failed to generate obfuscator %v", err) } else { @@ -55,7 +71,7 @@ func TestGenerateObfs(t *testing.T) { } }) t.Run("chacha20-poly1305", func(t *testing.T) { - obfuscator, err := GenerateObfs(E_METHOD_CHACHA20_POLY1305, sessionKey) + obfuscator, err := GenerateObfs(E_METHOD_CHACHA20_POLY1305, sessionKey, true) if err != nil { t.Errorf("failed to generate obfuscator %v", err) } else { @@ -63,13 +79,13 @@ func TestGenerateObfs(t *testing.T) { } }) t.Run("unknown encryption method", func(t *testing.T) { - _, err := GenerateObfs(0xff, sessionKey) + _, err := GenerateObfs(0xff, sessionKey, true) if err == nil { t.Errorf("unknown encryption mehtod error expected") } }) t.Run("bad key length", func(t *testing.T) { - _, err := GenerateObfs(0xff, sessionKey[:31]) + _, err := GenerateObfs(0xff, sessionKey[:31], true) if err == nil { t.Errorf("bad key length error expected") } @@ -94,7 +110,7 @@ func BenchmarkObfs(b *testing.B) { c, _ := aes.NewCipher(key[:]) payloadCipher, _ := cipher.NewGCM(c) - obfs := MakeObfs(key, payloadCipher) + obfs := MakeObfs(key, payloadCipher, true) b.ResetTimer() for i := 0; i < b.N; i++ { n, err := obfs(testFrame, obfsBuf) @@ -109,7 +125,7 @@ func BenchmarkObfs(b *testing.B) { c, _ := aes.NewCipher(key[:16]) payloadCipher, _ := cipher.NewGCM(c) - obfs := MakeObfs(key, payloadCipher) + obfs := MakeObfs(key, payloadCipher, true) b.ResetTimer() for i := 0; i < b.N; i++ { n, err := obfs(testFrame, obfsBuf) @@ -121,7 +137,7 @@ func BenchmarkObfs(b *testing.B) { } }) b.Run("plain", func(b *testing.B) { - obfs := MakeObfs(key, nil) + obfs := MakeObfs(key, nil, true) b.ResetTimer() for i := 0; i < b.N; i++ { n, err := obfs(testFrame, obfsBuf) @@ -135,7 +151,7 @@ func BenchmarkObfs(b *testing.B) { b.Run("chacha20Poly1305", func(b *testing.B) { payloadCipher, _ := chacha20poly1305.New(key[:16]) - obfs := MakeObfs(key, payloadCipher) + obfs := MakeObfs(key, payloadCipher, true) b.ResetTimer() for i := 0; i < b.N; i++ { n, err := obfs(testFrame, obfsBuf) @@ -166,9 +182,9 @@ func BenchmarkDeobfs(b *testing.B) { c, _ := aes.NewCipher(key[:]) payloadCipher, _ := cipher.NewGCM(c) - obfs := MakeObfs(key, payloadCipher) + obfs := MakeObfs(key, payloadCipher, true) n, _ := obfs(testFrame, obfsBuf) - deobfs := MakeDeobfs(key, payloadCipher) + deobfs := MakeDeobfs(key, payloadCipher, true) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -184,9 +200,9 @@ func BenchmarkDeobfs(b *testing.B) { c, _ := aes.NewCipher(key[:16]) payloadCipher, _ := cipher.NewGCM(c) - obfs := MakeObfs(key, payloadCipher) + obfs := MakeObfs(key, payloadCipher, true) n, _ := obfs(testFrame, obfsBuf) - deobfs := MakeDeobfs(key, payloadCipher) + deobfs := MakeDeobfs(key, payloadCipher, true) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -199,9 +215,9 @@ func BenchmarkDeobfs(b *testing.B) { } }) b.Run("plain", func(b *testing.B) { - obfs := MakeObfs(key, nil) + obfs := MakeObfs(key, nil, true) n, _ := obfs(testFrame, obfsBuf) - deobfs := MakeDeobfs(key, nil) + deobfs := MakeDeobfs(key, nil, true) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -216,9 +232,9 @@ func BenchmarkDeobfs(b *testing.B) { b.Run("chacha20Poly1305", func(b *testing.B) { payloadCipher, _ := chacha20poly1305.New(key[:16]) - obfs := MakeObfs(key, payloadCipher) + obfs := MakeObfs(key, payloadCipher, true) n, _ := obfs(testFrame, obfsBuf) - deobfs := MakeDeobfs(key, payloadCipher) + deobfs := MakeDeobfs(key, payloadCipher, true) b.ResetTimer() for i := 0; i < b.N; i++ { diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index fc8ce46..40bcb12 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -30,6 +30,8 @@ type Obfuscator struct { type switchboardStrategy int type SessionConfig struct { + NoRecordLayer bool + *Obfuscator Valve diff --git a/internal/multiplex/session_test.go b/internal/multiplex/session_test.go index d52837d..d1579d5 100644 --- a/internal/multiplex/session_test.go +++ b/internal/multiplex/session_test.go @@ -35,7 +35,7 @@ func TestRecvDataFromRemote(t *testing.T) { sessionKey := make([]byte, 32) rand.Read(sessionKey) t.Run("plain ordered", func(t *testing.T) { - obfuscator, _ := GenerateObfs(E_METHOD_PLAIN, sessionKey) + obfuscator, _ := GenerateObfs(E_METHOD_PLAIN, sessionKey, true) seshConfigOrdered.Obfuscator = obfuscator sesh := MakeSession(0, seshConfigOrdered) n, _ := sesh.Obfs(f, obfsBuf) @@ -58,7 +58,7 @@ func TestRecvDataFromRemote(t *testing.T) { } }) t.Run("aes-gcm ordered", func(t *testing.T) { - obfuscator, _ := GenerateObfs(E_METHOD_AES_GCM, sessionKey) + obfuscator, _ := GenerateObfs(E_METHOD_AES_GCM, sessionKey, true) seshConfigOrdered.Obfuscator = obfuscator sesh := MakeSession(0, seshConfigOrdered) n, _ := sesh.Obfs(f, obfsBuf) @@ -81,7 +81,7 @@ func TestRecvDataFromRemote(t *testing.T) { } }) t.Run("chacha20-poly1305 ordered", func(t *testing.T) { - obfuscator, _ := GenerateObfs(E_METHOD_CHACHA20_POLY1305, sessionKey) + obfuscator, _ := GenerateObfs(E_METHOD_CHACHA20_POLY1305, sessionKey, true) seshConfigOrdered.Obfuscator = obfuscator sesh := MakeSession(0, seshConfigOrdered) n, _ := sesh.Obfs(f, obfsBuf) @@ -105,7 +105,7 @@ func TestRecvDataFromRemote(t *testing.T) { }) t.Run("plain unordered", func(t *testing.T) { - obfuscator, _ := GenerateObfs(E_METHOD_PLAIN, sessionKey) + obfuscator, _ := GenerateObfs(E_METHOD_PLAIN, sessionKey, true) seshConfigUnordered.Obfuscator = obfuscator sesh := MakeSession(0, seshConfigOrdered) n, _ := sesh.Obfs(f, obfsBuf) @@ -146,7 +146,7 @@ func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) { rand.Read(sessionKey) b.Run("plain", func(b *testing.B) { - obfuscator, _ := GenerateObfs(E_METHOD_PLAIN, sessionKey) + obfuscator, _ := GenerateObfs(E_METHOD_PLAIN, sessionKey, true) seshConfigOrdered.Obfuscator = obfuscator sesh := MakeSession(0, seshConfigOrdered) n, _ := sesh.Obfs(f, obfsBuf) @@ -159,7 +159,7 @@ func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) { }) b.Run("aes-gcm", func(b *testing.B) { - obfuscator, _ := GenerateObfs(E_METHOD_AES_GCM, sessionKey) + obfuscator, _ := GenerateObfs(E_METHOD_AES_GCM, sessionKey, true) seshConfigOrdered.Obfuscator = obfuscator sesh := MakeSession(0, seshConfigOrdered) n, _ := sesh.Obfs(f, obfsBuf) @@ -172,7 +172,7 @@ func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) { }) b.Run("chacha20-poly1305", func(b *testing.B) { - obfuscator, _ := GenerateObfs(E_METHOD_CHACHA20_POLY1305, sessionKey) + obfuscator, _ := GenerateObfs(E_METHOD_CHACHA20_POLY1305, sessionKey, true) seshConfigOrdered.Obfuscator = obfuscator sesh := MakeSession(0, seshConfigOrdered) n, _ := sesh.Obfs(f, obfsBuf) diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index 016fe30..77179dd 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -51,7 +51,7 @@ func makeStream(sesh *Session, id uint32, assignedConnId uint32) *Stream { id: id, session: sesh, recvBuf: recvBuf, - obfsBuf: make([]byte, 17000), + obfsBuf: make([]byte, 17000), //TODO don't leave this hardcoded assignedConnId: assignedConnId, } diff --git a/internal/multiplex/stream_test.go b/internal/multiplex/stream_test.go index 96f521d..1592fed 100644 --- a/internal/multiplex/stream_test.go +++ b/internal/multiplex/stream_test.go @@ -14,7 +14,7 @@ import ( func setupSesh(unordered bool) *Session { sessionKey := make([]byte, 32) rand.Read(sessionKey) - obfuscator, _ := GenerateObfs(0x00, sessionKey) + obfuscator, _ := GenerateObfs(0x00, sessionKey, true) seshConfig := &SessionConfig{ Obfuscator: obfuscator, @@ -144,6 +144,7 @@ func TestStream_Read(t *testing.T) { _, err := conn.Write(data) if err != nil { t.Error("cannot write to connection", err) + return } } }() @@ -163,17 +164,21 @@ func TestStream_Read(t *testing.T) { stream, err := sesh.Accept() if err != nil { t.Error("failed to accept stream", err) + return } i, err = stream.Read(buf) if err != nil { t.Error("failed to read", err) + return } if i != PAYLOAD_LEN { t.Errorf("expected read %v, got %v", PAYLOAD_LEN, i) + return } if !bytes.Equal(buf[:i], testPayload) { t.Error("expected", testPayload, "got", buf[:i]) + return } }) t.Run("Nil buf", func(t *testing.T) { diff --git a/internal/server/TLS.go b/internal/server/TLS.go index 09ac687..081609d 100644 --- a/internal/server/TLS.go +++ b/internal/server/TLS.go @@ -216,11 +216,6 @@ func composeReply(ch *ClientHello, sharedSecret []byte, sessionKey []byte) ([]by return ret, nil } -var ErrBadClientHello = errors.New("non (or malformed) ClientHello") -var ErrNotCloak = errors.New("TLS but non-Cloak ClientHello") -var ErrReplay = errors.New("duplicate random") -var ErrBadProxyMethod = errors.New("invalid proxy method") - func unmarshalClientHello(ch *ClientHello, staticPv crypto.PrivateKey) (ai authenticationInfo, err error) { ephPub, ok := ecdh.Unmarshal(ch.random) if !ok { diff --git a/internal/server/auth.go b/internal/server/auth.go index c537a16..07963ae 100644 --- a/internal/server/auth.go +++ b/internal/server/auth.go @@ -22,6 +22,7 @@ type ClientInfo struct { ProxyMethod string EncryptionMethod byte Unordered bool + Transport Transport } type authenticationInfo struct { @@ -67,14 +68,20 @@ func touchStone(ai authenticationInfo, now func() time.Time) (info ClientInfo, e return } +var ErrBadClientHello = errors.New("non (or malformed) ClientHello") +var ErrReplay = errors.New("duplicate random") +var ErrBadProxyMethod = errors.New("invalid proxy method") + // PrepareConnection checks if the first packet of data is ClientHello or HTTP GET, and checks if it was from a Cloak client // if it is from a Cloak client, it returns the ClientInfo with the decrypted fields. It doesn't check if the user // is authorised. It also returns a finisher callback function to be called when the caller wishes to proceed with // the handshake -func PrepareConnection(firstPacket []byte, sta *State, conn net.Conn) (info ClientInfo, finisher func([]byte) error, err error) { +func PrepareConnection(firstPacket []byte, sta *State, conn net.Conn) (info ClientInfo, finisher func([]byte) (net.Conn, error), err error) { + var transport Transport var ai authenticationInfo switch firstPacket[0] { case 0x47: + transport = &WebSocket{} var req *http.Request req, err = http.ReadRequest(bufio.NewReader(bytes.NewBuffer(firstPacket))) if err != nil { @@ -90,30 +97,33 @@ func PrepareConnection(firstPacket []byte, sta *State, conn net.Conn) (info Clie return } - finisher = func(sessionKey []byte) error { + finisher = func(sessionKey []byte) (preparedConn net.Conn, err error) { handler := newWsHandshakeHandler() - go http.Serve(newWsAcceptor(conn, firstPacket), handler) + http.Serve(newWsAcceptor(conn, firstPacket), handler) <-handler.finished - conn = handler.conn + preparedConn = handler.conn nonce := make([]byte, 12) rand.Read(nonce) // reply: [12 bytes nonce][32 bytes encrypted session key][16 bytes authentication tag] encryptedKey, err := util.AESGCMEncrypt(nonce, ai.sharedSecret, sessionKey) // 32 + 16 = 48 bytes if err != nil { - return fmt.Errorf("failed to encrypt reply: %v", err) + err = fmt.Errorf("failed to encrypt reply: %v", err) + return } reply := append(nonce, encryptedKey...) - _, err = conn.Write(reply) + _, err = preparedConn.Write(reply) if err != nil { - go conn.Close() - return fmt.Errorf("failed to write reply: %v", err) + err = fmt.Errorf("failed to write reply: %v", err) + go preparedConn.Close() + return } - return nil + return } case 0x16: + transport = &TLS{} var ch *ClientHello ch, err = parseClientHello(firstPacket) if err != nil { @@ -132,17 +142,21 @@ func PrepareConnection(firstPacket []byte, sta *State, conn net.Conn) (info Clie err = fmt.Errorf("failed to unmarshal ClientHello into authenticationInfo: %v", err) return } - finisher = func(sessionKey []byte) error { + + finisher = func(sessionKey []byte) (preparedConn net.Conn, err error) { + preparedConn = conn reply, err := composeReply(ch, ai.sharedSecret, sessionKey) if err != nil { - return fmt.Errorf("failed to compose TLS reply: %v", err) + err = fmt.Errorf("failed to compose TLS reply: %v", err) + return } - _, err = conn.Write(reply) + _, err = preparedConn.Write(reply) if err != nil { - go conn.Close() - return fmt.Errorf("failed to write TLS reply: %v", err) + err = fmt.Errorf("failed to write TLS reply: %v", err) + go preparedConn.Close() + return } - return nil + return } default: err = ErrUnreconisedProtocol @@ -152,9 +166,10 @@ func PrepareConnection(firstPacket []byte, sta *State, conn net.Conn) (info Clie info, err = touchStone(ai, sta.Now) if err != nil { log.Debug(err) - err = ErrNotCloak + err = fmt.Errorf("transport %v in correct format but not Cloak: %v", err) return } + info.Transport = transport if _, ok := sta.ProxyBook[info.ProxyMethod]; !ok { err = ErrBadProxyMethod return diff --git a/internal/server/transport.go b/internal/server/transport.go new file mode 100644 index 0000000..b956f05 --- /dev/null +++ b/internal/server/transport.go @@ -0,0 +1,21 @@ +package server + +import ( + "github.com/cbeuw/Cloak/internal/util" + "net" +) + +type Transport interface { + HasRecordLayer() bool + UnitReadFunc() func(net.Conn, []byte) (int, error) +} + +type TLS struct{} + +func (*TLS) HasRecordLayer() bool { return true } +func (*TLS) UnitReadFunc() func(net.Conn, []byte) (int, error) { return util.ReadTLS } + +type WebSocket struct{} + +func (*WebSocket) HasRecordLayer() bool { return false } +func (*WebSocket) UnitReadFunc() func(net.Conn, []byte) (int, error) { return util.ReadWebSocket } diff --git a/internal/server/websocket.go b/internal/server/websocket.go index e9bee82..020b424 100644 --- a/internal/server/websocket.go +++ b/internal/server/websocket.go @@ -21,12 +21,13 @@ type firstBuffedConn struct { func (c *firstBuffedConn) Read(buf []byte) (int, error) { if !c.firstRead { + c.firstRead = true copy(buf, c.firstPacket) n := len(c.firstPacket) c.firstPacket = []byte{} return n, nil } - return c.Read(buf) + return c.Conn.Read(buf) } type wsAcceptor struct { @@ -38,7 +39,7 @@ func newWsAcceptor(conn net.Conn, first []byte) *wsAcceptor { f := make([]byte, len(first)) copy(f, first) return &wsAcceptor{ - c: &firstBuffedConn{Conn: conn, firstPacket: first}, + c: &firstBuffedConn{Conn: conn, firstPacket: f}, } } @@ -69,16 +70,13 @@ func newWsHandshakeHandler() *wsHandshakeHandler { } func (ws *wsHandshakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - upgrader := websocket.Upgrader{ - ReadBufferSize: 16380, - WriteBufferSize: 16380, - } + upgrader := websocket.Upgrader{} c, err := upgrader.Upgrade(w, r, nil) if err != nil { log.Errorf("failed to upgrade connection to ws: %v", err) return } - ws.conn = &util.WebSocketConn{c} + ws.conn = &util.WebSocketConn{Conn: c} ws.finished <- struct{}{} } diff --git a/internal/util/util.go b/internal/util/util.go index 4d4c096..add07a1 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -89,10 +89,10 @@ func AddRecordLayer(input []byte, typ []byte, ver []byte) []byte { } func Pipe(dst net.Conn, src net.Conn, srcReadTimeout time.Duration) { - // The maximum size of TLS message will be 16380+12+16. 12 because of the stream header and 16 + // The maximum size of TLS message will be 16380+14+16. 14 because of the stream header and 16 // because of the salt/mac // 16408 is the max TLS message size on Firefox - buf := make([]byte, 16380) + buf := make([]byte, 16378) if srcReadTimeout != 0 { src.SetReadDeadline(time.Now().Add(srcReadTimeout)) } diff --git a/internal/util/websocket.go b/internal/util/websocket.go index 37ddfa0..a3597ca 100644 --- a/internal/util/websocket.go +++ b/internal/util/websocket.go @@ -1,16 +1,23 @@ package util import ( + "errors" "github.com/gorilla/websocket" + "io" + "net" + "sync" "time" ) type WebSocketConn struct { *websocket.Conn + writeM sync.Mutex } func (ws *WebSocketConn) Write(data []byte) (int, error) { + ws.writeM.Lock() err := ws.WriteMessage(websocket.BinaryMessage, data) + ws.writeM.Unlock() if err != nil { return 0, err } else { @@ -18,12 +25,41 @@ func (ws *WebSocketConn) Write(data []byte) (int, error) { } } -func (ws *WebSocketConn) Read(buf []byte) (int, error) { - _, r, err := ws.NextReader() +func (ws *WebSocketConn) Read(buf []byte) (n int, err error) { + t, r, err := ws.NextReader() if err != nil { return 0, err } - return r.Read(buf) + if t != websocket.BinaryMessage { + return 0, nil + } + + // Read until io.EOL for one full message + for { + var read int + read, err = r.Read(buf[n:]) + if err != nil { + if err == io.EOF { + err = nil + break + } else { + break + } + } else { + // There may be data available to read but n == len(buf)-1, read==0 because buffer is full + if read == 0 { + err = errors.New("nothing more is read. message may be larger than buffer") + break + } + } + n += read + } + return +} +func (ws *WebSocketConn) Close() error { + ws.writeM.Lock() + defer ws.writeM.Unlock() + return ws.Conn.Close() } func (ws *WebSocketConn) SetDeadline(t time.Time) error { @@ -37,3 +73,8 @@ func (ws *WebSocketConn) SetDeadline(t time.Time) error { } return nil } + +// ws unit reader +func ReadWebSocket(conn net.Conn, buffer []byte) (n int, err error) { + return conn.Read(buffer) +}