diff --git a/cmd/ck-server/ck-server.go b/cmd/ck-server/ck-server.go index 143bcf2..0f58ea1 100644 --- a/cmd/ck-server/ck-server.go +++ b/cmd/ck-server/ck-server.go @@ -58,7 +58,7 @@ func dispatchConnection(conn net.Conn, sta *server.State) { go util.Pipe(conn, webConn, 0) } - ci, finishHandshake, err := server.PrepareConnection(data, sta, conn) + ci, finishHandshake, err := server.AuthFirstPacket(data, sta) if err != nil { log.WithFields(log.Fields{ "remoteAddr": remoteAddr, @@ -84,7 +84,7 @@ func dispatchConnection(conn net.Conn, sta *server.State) { // added to the userinfo database. The distinction between going into the admin mode // and normal proxy mode is that sessionID needs == 0 for admin mode if bytes.Equal(ci.UID, sta.AdminUID) && ci.SessionId == 0 { - preparedConn, err := finishHandshake(sessionKey) + preparedConn, err := finishHandshake(conn, sessionKey) if err != nil { log.Error(err) return @@ -136,7 +136,7 @@ func dispatchConnection(conn net.Conn, sta *server.State) { } if existing { - preparedConn, err := finishHandshake(sesh.SessionKey) + preparedConn, err := finishHandshake(conn, sesh.SessionKey) if err != nil { log.Error(err) return @@ -146,7 +146,7 @@ func dispatchConnection(conn net.Conn, sta *server.State) { return } - preparedConn, err := finishHandshake(sessionKey) + preparedConn, err := finishHandshake(conn, sessionKey) if err != nil { log.Error(err) return diff --git a/internal/server/TLS.go b/internal/server/TLS.go index d68620d..5111901 100644 --- a/internal/server/TLS.go +++ b/internal/server/TLS.go @@ -19,24 +19,29 @@ func (TLS) String() string { return "TLS" } func (TLS) HasRecordLayer() bool { return true } func (TLS) UnitReadFunc() func(net.Conn, []byte) (int, error) { return util.ReadTLS } -func (TLS) handshake(clientHello []byte, privateKey crypto.PrivateKey, originalConn net.Conn) (fragments authFragments, finisher func([]byte) (net.Conn, error), err error) { - var ch *ClientHello - ch, err = parseClientHello(clientHello) +func (TLS) processFirstPacket(clientHello []byte, privateKey crypto.PrivateKey) (fragments authFragments, respond Responder, err error) { + ch, err := parseClientHello(clientHello) if err != nil { log.Debug(err) err = ErrBadClientHello return } - fragments, err = unmarshalClientHello(ch, privateKey) + fragments, err = TLS{}.unmarshalClientHello(ch, privateKey) if err != nil { err = fmt.Errorf("failed to unmarshal ClientHello into authFragments: %v", err) return } - finisher = func(sessionKey []byte) (preparedConn net.Conn, err error) { + respond = TLS{}.makeResponder(ch.sessionId, fragments.sharedSecret[:]) + + return +} + +func (TLS) makeResponder(clientHelloSessionId []byte, sharedSecret []byte) Responder { + respond := func(originalConn net.Conn, sessionKey []byte) (preparedConn net.Conn, err error) { preparedConn = originalConn - reply, err := composeReply(ch, fragments.sharedSecret[:], sessionKey) + reply, err := composeReply(clientHelloSessionId, sharedSecret, sessionKey) if err != nil { err = fmt.Errorf("failed to compose TLS reply: %v", err) return @@ -49,11 +54,10 @@ func (TLS) handshake(clientHello []byte, privateKey crypto.PrivateKey, originalC } return } - - return + return respond } -func unmarshalClientHello(ch *ClientHello, staticPv crypto.PrivateKey) (fragments authFragments, err error) { +func (TLS) unmarshalClientHello(ch *ClientHello, staticPv crypto.PrivateKey) (fragments authFragments, err error) { copy(fragments.randPubKey[:], ch.random) ephPub, ok := ecdh.Unmarshal(fragments.randPubKey[:]) if !ok { diff --git a/internal/server/TLSAux.go b/internal/server/TLSAux.go index 52f9a66..1724dc0 100644 --- a/internal/server/TLSAux.go +++ b/internal/server/TLSAux.go @@ -198,9 +198,9 @@ func composeServerHello(sessionId []byte, sharedSecret []byte, sessionKey []byte // composeReply composes the ServerHello, ChangeCipherSpec and an ApplicationData messages // together with their respective record layers into one byte slice. -func composeReply(ch *ClientHello, sharedSecret []byte, sessionKey []byte) ([]byte, error) { +func composeReply(clientHelloSessionId []byte, sharedSecret []byte, sessionKey []byte) ([]byte, error) { TLS12 := []byte{0x03, 0x03} - sh, err := composeServerHello(ch.sessionId, sharedSecret, sessionKey) + sh, err := composeServerHello(clientHelloSessionId, sharedSecret, sessionKey) if err != nil { return nil, err } diff --git a/internal/server/auth.go b/internal/server/auth.go index 9eeba4a..a1f3ff4 100644 --- a/internal/server/auth.go +++ b/internal/server/auth.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "github.com/cbeuw/Cloak/internal/util" - "net" "time" log "github.com/sirupsen/logrus" @@ -34,8 +33,8 @@ const ( var ErrTimestampOutOfWindow = errors.New("timestamp is outside of the accepting window") var ErrUnreconisedProtocol = errors.New("unreconised protocol") -// touchStone checks if a the authFragments are valid. It doesn't check if the UID is authorised -func touchStone(fragments authFragments, now func() time.Time) (info ClientInfo, err error) { +// decryptClientInfo checks if a the authFragments are valid. It doesn't check if the UID is authorised +func decryptClientInfo(fragments authFragments, now func() time.Time) (info ClientInfo, err error) { var plaintext []byte plaintext, err = util.AESGCMDecrypt(fragments.randPubKey[0:12], fragments.sharedSecret[:], fragments.ciphertextWithTag[:]) if err != nil { @@ -64,25 +63,23 @@ func touchStone(fragments authFragments, now func() time.Time) (info ClientInfo, var ErrReplay = errors.New("duplicate random") var ErrBadProxyMethod = errors.New("invalid proxy method") -// PrepareConnection checks if the first packet of data is ClientHello or HTTP GET, and checks if it was from a Cloak client +// AuthFirstPacket checks if the first packet of data is ClientHello or HTTP GET, and checks if it was from a Cloak client // if it is from a Cloak client, it returns the ClientInfo with the decrypted fields. It doesn't check if the user // is authorised. It also returns a finisher callback function to be called when the caller wishes to proceed with // the handshake -func PrepareConnection(firstPacket []byte, sta *State, conn net.Conn) (info ClientInfo, finisher func([]byte) (net.Conn, error), err error) { +func AuthFirstPacket(firstPacket []byte, sta *State) (info ClientInfo, finisher Responder, err error) { var transport Transport switch firstPacket[0] { case 0x47: - transport = WebSocket{} + transport = &WebSocket{} case 0x16: - transport = TLS{} + transport = &TLS{} default: err = ErrUnreconisedProtocol return } - var fragments authFragments - fragments, finisher, err = transport.handshake(firstPacket, sta.staticPv, conn) - + fragments, finisher, err := transport.processFirstPacket(firstPacket, sta.staticPv) if err != nil { return } @@ -92,7 +89,7 @@ func PrepareConnection(firstPacket []byte, sta *State, conn net.Conn) (info Clie return } - info, err = touchStone(fragments, sta.Now) + info, err = decryptClientInfo(fragments, sta.Now) if err != nil { log.Debug(err) err = fmt.Errorf("transport %v in correct format but not Cloak: %v", transport, err) diff --git a/internal/server/transport.go b/internal/server/transport.go index 9fbe898..87daf76 100644 --- a/internal/server/transport.go +++ b/internal/server/transport.go @@ -6,10 +6,11 @@ import ( "net" ) +type Responder = func(originalConn net.Conn, sessionKey []byte) (preparedConn net.Conn, err error) type Transport interface { HasRecordLayer() bool UnitReadFunc() func(net.Conn, []byte) (int, error) - handshake(reqPacket []byte, privateKey crypto.PrivateKey, originalConn net.Conn) (authFragments, func([]byte) (net.Conn, error), error) + processFirstPacket(reqPacket []byte, privateKey crypto.PrivateKey) (authFragments, Responder, error) } var ErrInvalidPubKey = errors.New("public key has invalid format") diff --git a/internal/server/websocket.go b/internal/server/websocket.go index 92ed109..4134bf1 100644 --- a/internal/server/websocket.go +++ b/internal/server/websocket.go @@ -13,13 +13,15 @@ import ( "net/http" ) -type WebSocket struct{} +type WebSocket struct { + requestPacket []byte +} func (WebSocket) String() string { return "WebSocket" } func (WebSocket) HasRecordLayer() bool { return false } func (WebSocket) UnitReadFunc() func(net.Conn, []byte) (int, error) { return util.ReadWebSocket } -func (WebSocket) handshake(reqPacket []byte, privateKey crypto.PrivateKey, originalConn net.Conn) (fragments authFragments, finisher func([]byte) (net.Conn, error), err error) { +func (WebSocket) processFirstPacket(reqPacket []byte, privateKey crypto.PrivateKey) (fragments authFragments, respond Responder, err error) { var req *http.Request req, err = http.ReadRequest(bufio.NewReader(bytes.NewBuffer(reqPacket))) if err != nil { @@ -29,13 +31,19 @@ func (WebSocket) handshake(reqPacket []byte, privateKey crypto.PrivateKey, origi var hiddenData []byte hiddenData, err = base64.StdEncoding.DecodeString(req.Header.Get("hidden")) - fragments, err = unmarshalHidden(hiddenData, privateKey) + fragments, err = WebSocket{}.unmarshalHidden(hiddenData, privateKey) if err != nil { err = fmt.Errorf("failed to unmarshal hidden data from WS into authFragments: %v", err) return } - finisher = func(sessionKey []byte) (preparedConn net.Conn, err error) { + respond = WebSocket{}.makeResponder(reqPacket, fragments.sharedSecret[:]) + + return +} + +func (WebSocket) makeResponder(reqPacket []byte, sharedSecret []byte) Responder { + respond := func(originalConn net.Conn, sessionKey []byte) (preparedConn net.Conn, err error) { handler := newWsHandshakeHandler() // For an explanation of the following 3 lines, see the comments in websocketAux.go @@ -47,7 +55,7 @@ func (WebSocket) handshake(reqPacket []byte, privateKey crypto.PrivateKey, origi util.CryptoRandRead(nonce) // reply: [12 bytes nonce][32 bytes encrypted session key][16 bytes authentication tag] - encryptedKey, err := util.AESGCMEncrypt(nonce, fragments.sharedSecret[:], sessionKey) // 32 + 16 = 48 bytes + encryptedKey, err := util.AESGCMEncrypt(nonce, sharedSecret, sessionKey) // 32 + 16 = 48 bytes if err != nil { err = fmt.Errorf("failed to encrypt reply: %v", err) return @@ -61,13 +69,12 @@ func (WebSocket) handshake(reqPacket []byte, privateKey crypto.PrivateKey, origi } return } - - return + return respond } var ErrBadGET = errors.New("non (or malformed) HTTP GET") -func unmarshalHidden(hidden []byte, staticPv crypto.PrivateKey) (fragments authFragments, err error) { +func (WebSocket) unmarshalHidden(hidden []byte, staticPv crypto.PrivateKey) (fragments authFragments, err error) { if len(hidden) < 96 { err = ErrBadGET return