diff --git a/internal/server/TLS.go b/internal/server/TLS.go index 081609d..e9fe36d 100644 --- a/internal/server/TLS.go +++ b/internal/server/TLS.go @@ -1,221 +1,58 @@ package server import ( - "bytes" "crypto" - "crypto/rand" - "encoding/binary" - "encoding/hex" "errors" "fmt" "github.com/cbeuw/Cloak/internal/ecdh" "github.com/cbeuw/Cloak/internal/util" + "net" + + log "github.com/sirupsen/logrus" ) -// ClientHello contains every field in a ClientHello message -type ClientHello struct { - handshakeType byte - length int - clientVersion []byte - random []byte - sessionIdLen int - sessionId []byte - cipherSuitesLen int - cipherSuites []byte - compressionMethodsLen int - compressionMethods []byte - extensionsLen int - extensions map[[2]byte][]byte -} +type TLS struct{} -var u16 = binary.BigEndian.Uint16 -var u32 = binary.BigEndian.Uint32 +var ErrBadClientHello = errors.New("non (or malformed) ClientHello") -func parseExtensions(input []byte) (ret map[[2]byte][]byte, err error) { - defer func() { - if r := recover(); r != nil { - err = errors.New("Malformed Extensions") +func (TLS) String() string { return "TLS" } +func (TLS) HasRecordLayer() bool { return true } +func (TLS) UnitReadFunc() func(net.Conn, []byte) (int, error) { return util.ReadTLS } + +func (TLS) handshake(clientHello []byte, privateKey crypto.PrivateKey, originalConn net.Conn) (ai authenticationInfo, finisher func([]byte) (net.Conn, error), err error) { + var ch *ClientHello + ch, err = parseClientHello(clientHello) + if err != nil { + log.Debug(err) + err = ErrBadClientHello + return + } + + ai, err = unmarshalClientHello(ch, privateKey) + if err != nil { + err = fmt.Errorf("failed to unmarshal ClientHello into authenticationInfo: %v", err) + return + } + + finisher = func(sessionKey []byte) (preparedConn net.Conn, err error) { + preparedConn = originalConn + reply, err := composeReply(ch, ai.sharedSecret, sessionKey) + if err != nil { + err = fmt.Errorf("failed to compose TLS reply: %v", err) + return } - }() - pointer := 0 - totalLen := len(input) - ret = make(map[[2]byte][]byte) - for pointer < totalLen { - var typ [2]byte - copy(typ[:], input[pointer:pointer+2]) - pointer += 2 - length := int(u16(input[pointer : pointer+2])) - pointer += 2 - data := input[pointer : pointer+length] - pointer += length - ret[typ] = data - } - return ret, err -} - -func parseKeyShare(input []byte) (ret []byte, err error) { - defer func() { - if r := recover(); r != nil { - err = errors.New("malformed key_share") + _, err = preparedConn.Write(reply) + if err != nil { + err = fmt.Errorf("failed to write TLS reply: %v", err) + go preparedConn.Close() + return } - }() - totalLen := int(u16(input[0:2])) - // 2 bytes "client key share length" - pointer := 2 - for pointer < totalLen { - if bytes.Equal([]byte{0x00, 0x1d}, input[pointer:pointer+2]) { - // skip "key exchange length" - pointer += 2 - length := int(u16(input[pointer : pointer+2])) - pointer += 2 - if length != 32 { - return nil, fmt.Errorf("key share length should be 32, instead of %v", length) - } - return input[pointer : pointer+length], nil - } - pointer += 2 - length := int(u16(input[pointer : pointer+2])) - pointer += 2 - _ = input[pointer : pointer+length] - pointer += length - } - return nil, errors.New("x25519 does not exist") -} - -// addRecordLayer adds record layer to data -func addRecordLayer(input []byte, typ []byte, ver []byte) []byte { - length := make([]byte, 2) - binary.BigEndian.PutUint16(length, uint16(len(input))) - ret := make([]byte, 5+len(input)) - copy(ret[0:1], typ) - copy(ret[1:3], ver) - copy(ret[3:5], length) - copy(ret[5:], input) - return ret -} - -// parseClientHello parses everything on top of the TLS layer -// (including the record layer) into ClientHello type -func parseClientHello(data []byte) (ret *ClientHello, err error) { - defer func() { - if r := recover(); r != nil { - err = errors.New("Malformed ClientHello") - } - }() - - if !bytes.Equal(data[0:3], []byte{0x16, 0x03, 0x01}) { - return ret, errors.New("wrong TLS1.3 handshake magic bytes") + return } - peeled := make([]byte, len(data)-5) - copy(peeled, data[5:]) - pointer := 0 - // Handshake Type - handshakeType := peeled[pointer] - if handshakeType != 0x01 { - return ret, errors.New("Not a ClientHello") - } - pointer += 1 - // Length - length := int(u32(append([]byte{0x00}, peeled[pointer:pointer+3]...))) - pointer += 3 - if length != len(peeled[pointer:]) { - return ret, errors.New("Hello length doesn't match") - } - // Client Version - clientVersion := peeled[pointer : pointer+2] - pointer += 2 - // Random - random := peeled[pointer : pointer+32] - pointer += 32 - // Session ID - sessionIdLen := int(peeled[pointer]) - pointer += 1 - sessionId := peeled[pointer : pointer+sessionIdLen] - pointer += sessionIdLen - // Cipher Suites - cipherSuitesLen := int(u16(peeled[pointer : pointer+2])) - pointer += 2 - cipherSuites := peeled[pointer : pointer+cipherSuitesLen] - pointer += cipherSuitesLen - // Compression Methods - compressionMethodsLen := int(peeled[pointer]) - pointer += 1 - compressionMethods := peeled[pointer : pointer+compressionMethodsLen] - pointer += compressionMethodsLen - // Extensions - extensionsLen := int(u16(peeled[pointer : pointer+2])) - pointer += 2 - extensions, err := parseExtensions(peeled[pointer:]) - ret = &ClientHello{ - handshakeType, - length, - clientVersion, - random, - sessionIdLen, - sessionId, - cipherSuitesLen, - cipherSuites, - compressionMethodsLen, - compressionMethods, - extensionsLen, - extensions, - } return } -func composeServerHello(sessionId []byte, sharedSecret []byte, sessionKey []byte) ([]byte, error) { - nonce := make([]byte, 12) - rand.Read(nonce) - - encryptedKey, err := util.AESGCMEncrypt(nonce, sharedSecret, sessionKey) // 32 + 16 = 48 bytes - if err != nil { - return nil, err - } - - var serverHello [11][]byte - serverHello[0] = []byte{0x02} // handshake type - serverHello[1] = []byte{0x00, 0x00, 0x76} // length 77 - serverHello[2] = []byte{0x03, 0x03} // server version - serverHello[3] = append(nonce[0:12], encryptedKey[0:20]...) // random 32 bytes - serverHello[4] = []byte{0x20} // session id length 32 - serverHello[5] = sessionId // session id - serverHello[6] = []byte{0xc0, 0x30} // cipher suite TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 - serverHello[7] = []byte{0x00} // compression method null - serverHello[8] = []byte{0x00, 0x2e} // extensions length 46 - - keyShare, _ := hex.DecodeString("00330024001d0020") - keyExchange := make([]byte, 32) - copy(keyExchange, encryptedKey[20:48]) - rand.Read(keyExchange[28:32]) - serverHello[9] = append(keyShare, keyExchange...) - - serverHello[10], _ = hex.DecodeString("002b00020304") - var ret []byte - for _, s := range serverHello { - ret = append(ret, s...) - } - return ret, nil -} - -// composeReply composes the ServerHello, ChangeCipherSpec and an ApplicationData messages -// together with their respective record layers into one byte slice. -func composeReply(ch *ClientHello, sharedSecret []byte, sessionKey []byte) ([]byte, error) { - TLS12 := []byte{0x03, 0x03} - sh, err := composeServerHello(ch.sessionId, sharedSecret, sessionKey) - if err != nil { - return nil, err - } - shBytes := addRecordLayer(sh, []byte{0x16}, TLS12) - ccsBytes := addRecordLayer([]byte{0x01}, []byte{0x14}, TLS12) - cert := make([]byte, 68) // TODO: add some different lengths maybe? - rand.Read(cert) - encryptedCertBytes := addRecordLayer(cert, []byte{0x17}, TLS12) - ret := append(shBytes, ccsBytes...) - ret = append(ret, encryptedCertBytes...) - return ret, nil -} - func unmarshalClientHello(ch *ClientHello, staticPv crypto.PrivateKey) (ai authenticationInfo, err error) { ephPub, ok := ecdh.Unmarshal(ch.random) if !ok { diff --git a/internal/server/TLSAux.go b/internal/server/TLSAux.go new file mode 100644 index 0000000..779584a --- /dev/null +++ b/internal/server/TLSAux.go @@ -0,0 +1,215 @@ +package server + +import ( + "bytes" + "crypto/rand" + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "github.com/cbeuw/Cloak/internal/util" +) + +// ClientHello contains every field in a ClientHello message +type ClientHello struct { + handshakeType byte + length int + clientVersion []byte + random []byte + sessionIdLen int + sessionId []byte + cipherSuitesLen int + cipherSuites []byte + compressionMethodsLen int + compressionMethods []byte + extensionsLen int + extensions map[[2]byte][]byte +} + +var u16 = binary.BigEndian.Uint16 +var u32 = binary.BigEndian.Uint32 + +func parseExtensions(input []byte) (ret map[[2]byte][]byte, err error) { + defer func() { + if r := recover(); r != nil { + err = errors.New("Malformed Extensions") + } + }() + pointer := 0 + totalLen := len(input) + ret = make(map[[2]byte][]byte) + for pointer < totalLen { + var typ [2]byte + copy(typ[:], input[pointer:pointer+2]) + pointer += 2 + length := int(u16(input[pointer : pointer+2])) + pointer += 2 + data := input[pointer : pointer+length] + pointer += length + ret[typ] = data + } + return ret, err +} + +func parseKeyShare(input []byte) (ret []byte, err error) { + defer func() { + if r := recover(); r != nil { + err = errors.New("malformed key_share") + } + }() + totalLen := int(u16(input[0:2])) + // 2 bytes "client key share length" + pointer := 2 + for pointer < totalLen { + if bytes.Equal([]byte{0x00, 0x1d}, input[pointer:pointer+2]) { + // skip "key exchange length" + pointer += 2 + length := int(u16(input[pointer : pointer+2])) + pointer += 2 + if length != 32 { + return nil, fmt.Errorf("key share length should be 32, instead of %v", length) + } + return input[pointer : pointer+length], nil + } + pointer += 2 + length := int(u16(input[pointer : pointer+2])) + pointer += 2 + _ = input[pointer : pointer+length] + pointer += length + } + return nil, errors.New("x25519 does not exist") +} + +// addRecordLayer adds record layer to data +func addRecordLayer(input []byte, typ []byte, ver []byte) []byte { + length := make([]byte, 2) + binary.BigEndian.PutUint16(length, uint16(len(input))) + ret := make([]byte, 5+len(input)) + copy(ret[0:1], typ) + copy(ret[1:3], ver) + copy(ret[3:5], length) + copy(ret[5:], input) + return ret +} + +// parseClientHello parses everything on top of the TLS layer +// (including the record layer) into ClientHello type +func parseClientHello(data []byte) (ret *ClientHello, err error) { + defer func() { + if r := recover(); r != nil { + err = errors.New("Malformed ClientHello") + } + }() + + if !bytes.Equal(data[0:3], []byte{0x16, 0x03, 0x01}) { + return ret, errors.New("wrong TLS1.3 handshake magic bytes") + } + + peeled := make([]byte, len(data)-5) + copy(peeled, data[5:]) + pointer := 0 + // Handshake Type + handshakeType := peeled[pointer] + if handshakeType != 0x01 { + return ret, errors.New("Not a ClientHello") + } + pointer += 1 + // Length + length := int(u32(append([]byte{0x00}, peeled[pointer:pointer+3]...))) + pointer += 3 + if length != len(peeled[pointer:]) { + return ret, errors.New("Hello length doesn't match") + } + // Client Version + clientVersion := peeled[pointer : pointer+2] + pointer += 2 + // Random + random := peeled[pointer : pointer+32] + pointer += 32 + // Session ID + sessionIdLen := int(peeled[pointer]) + pointer += 1 + sessionId := peeled[pointer : pointer+sessionIdLen] + pointer += sessionIdLen + // Cipher Suites + cipherSuitesLen := int(u16(peeled[pointer : pointer+2])) + pointer += 2 + cipherSuites := peeled[pointer : pointer+cipherSuitesLen] + pointer += cipherSuitesLen + // Compression Methods + compressionMethodsLen := int(peeled[pointer]) + pointer += 1 + compressionMethods := peeled[pointer : pointer+compressionMethodsLen] + pointer += compressionMethodsLen + // Extensions + extensionsLen := int(u16(peeled[pointer : pointer+2])) + pointer += 2 + extensions, err := parseExtensions(peeled[pointer:]) + ret = &ClientHello{ + handshakeType, + length, + clientVersion, + random, + sessionIdLen, + sessionId, + cipherSuitesLen, + cipherSuites, + compressionMethodsLen, + compressionMethods, + extensionsLen, + extensions, + } + return +} + +func composeServerHello(sessionId []byte, sharedSecret []byte, sessionKey []byte) ([]byte, error) { + nonce := make([]byte, 12) + rand.Read(nonce) + + encryptedKey, err := util.AESGCMEncrypt(nonce, sharedSecret, sessionKey) // 32 + 16 = 48 bytes + if err != nil { + return nil, err + } + + var serverHello [11][]byte + serverHello[0] = []byte{0x02} // handshake type + serverHello[1] = []byte{0x00, 0x00, 0x76} // length 77 + serverHello[2] = []byte{0x03, 0x03} // server version + serverHello[3] = append(nonce[0:12], encryptedKey[0:20]...) // random 32 bytes + serverHello[4] = []byte{0x20} // session id length 32 + serverHello[5] = sessionId // session id + serverHello[6] = []byte{0xc0, 0x30} // cipher suite TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 + serverHello[7] = []byte{0x00} // compression method null + serverHello[8] = []byte{0x00, 0x2e} // extensions length 46 + + keyShare, _ := hex.DecodeString("00330024001d0020") + keyExchange := make([]byte, 32) + copy(keyExchange, encryptedKey[20:48]) + rand.Read(keyExchange[28:32]) + serverHello[9] = append(keyShare, keyExchange...) + + serverHello[10], _ = hex.DecodeString("002b00020304") + var ret []byte + for _, s := range serverHello { + ret = append(ret, s...) + } + return ret, nil +} + +// composeReply composes the ServerHello, ChangeCipherSpec and an ApplicationData messages +// together with their respective record layers into one byte slice. +func composeReply(ch *ClientHello, sharedSecret []byte, sessionKey []byte) ([]byte, error) { + TLS12 := []byte{0x03, 0x03} + sh, err := composeServerHello(ch.sessionId, sharedSecret, sessionKey) + if err != nil { + return nil, err + } + shBytes := addRecordLayer(sh, []byte{0x16}, TLS12) + ccsBytes := addRecordLayer([]byte{0x01}, []byte{0x14}, TLS12) + cert := make([]byte, 68) // TODO: add some different lengths maybe? + rand.Read(cert) + encryptedCertBytes := addRecordLayer(cert, []byte{0x17}, TLS12) + ret := append(shBytes, ccsBytes...) + ret = append(ret, encryptedCertBytes...) + return ret, nil +} diff --git a/internal/server/TLS_test.go b/internal/server/TLSAux_test.go similarity index 100% rename from internal/server/TLS_test.go rename to internal/server/TLSAux_test.go diff --git a/internal/server/auth.go b/internal/server/auth.go index 01b0b22..e2ff1c9 100644 --- a/internal/server/auth.go +++ b/internal/server/auth.go @@ -1,16 +1,12 @@ package server import ( - "bufio" "bytes" - "crypto/rand" - "encoding/base64" "encoding/binary" "errors" "fmt" "github.com/cbeuw/Cloak/internal/util" "net" - "net/http" "time" log "github.com/sirupsen/logrus" @@ -35,8 +31,6 @@ const ( UNORDERED_FLAG = 0x01 // 0000 0001 ) -var ErrInvalidPubKey = errors.New("public key has invalid format") -var ErrCiphertextLength = errors.New("ciphertext has the wrong length") var ErrTimestampOutOfWindow = errors.New("timestamp is outside of the accepting window") var ErrUnreconisedProtocol = errors.New("unreconised protocol") @@ -67,7 +61,6 @@ func touchStone(ai authenticationInfo, now func() time.Time) (info ClientInfo, e return } -var ErrBadClientHello = errors.New("non (or malformed) ClientHello") var ErrReplay = errors.New("duplicate random") var ErrBadProxyMethod = errors.New("invalid proxy method") @@ -76,100 +69,34 @@ var ErrBadProxyMethod = errors.New("invalid proxy method") // is authorised. It also returns a finisher callback function to be called when the caller wishes to proceed with // the handshake func PrepareConnection(firstPacket []byte, sta *State, conn net.Conn) (info ClientInfo, finisher func([]byte) (net.Conn, error), err error) { - var transport Transport - var ai authenticationInfo switch firstPacket[0] { case 0x47: - transport = WebSocket{} - var req *http.Request - req, err = http.ReadRequest(bufio.NewReader(bytes.NewBuffer(firstPacket))) - if err != nil { - err = fmt.Errorf("failed to parse first HTTP GET: %v", err) - return - } - var hiddenData []byte - hiddenData, err = base64.StdEncoding.DecodeString(req.Header.Get("hidden")) - - ai, err = unmarshalHidden(hiddenData, sta.staticPv) - if err != nil { - err = fmt.Errorf("failed to unmarshal hidden data from WS into authenticationInfo: %v", err) - return - } - - finisher = func(sessionKey []byte) (preparedConn net.Conn, err error) { - handler := newWsHandshakeHandler() - - // For an explanation of the following 3 lines, see the comments in websocket.go - http.Serve(newWsAcceptor(conn, firstPacket), handler) - - <-handler.finished - preparedConn = handler.conn - nonce := make([]byte, 12) - rand.Read(nonce) - - // reply: [12 bytes nonce][32 bytes encrypted session key][16 bytes authentication tag] - encryptedKey, err := util.AESGCMEncrypt(nonce, ai.sharedSecret, sessionKey) // 32 + 16 = 48 bytes - if err != nil { - err = fmt.Errorf("failed to encrypt reply: %v", err) - return - } - reply := append(nonce, encryptedKey...) - _, err = preparedConn.Write(reply) - if err != nil { - err = fmt.Errorf("failed to write reply: %v", err) - go preparedConn.Close() - return - } - return - } + info.Transport = WebSocket{} case 0x16: - transport = TLS{} - var ch *ClientHello - ch, err = parseClientHello(firstPacket) - if err != nil { - log.Debug(err) - err = ErrBadClientHello - return - } - - if sta.registerRandom(ch.random) { - err = ErrReplay - return - } - - ai, err = unmarshalClientHello(ch, sta.staticPv) - if err != nil { - err = fmt.Errorf("failed to unmarshal ClientHello into authenticationInfo: %v", err) - return - } - - finisher = func(sessionKey []byte) (preparedConn net.Conn, err error) { - preparedConn = conn - reply, err := composeReply(ch, ai.sharedSecret, sessionKey) - if err != nil { - err = fmt.Errorf("failed to compose TLS reply: %v", err) - return - } - _, err = preparedConn.Write(reply) - if err != nil { - err = fmt.Errorf("failed to write TLS reply: %v", err) - go preparedConn.Close() - return - } - return - } + info.Transport = TLS{} default: err = ErrUnreconisedProtocol return } + var ai authenticationInfo + ai, finisher, err = info.Transport.handshake(firstPacket, sta.staticPv, conn) + + if err != nil { + return + } + + if sta.registerRandom(ai.nonce) { + err = ErrReplay + return + } + info, err = touchStone(ai, sta.Now) if err != nil { log.Debug(err) - err = fmt.Errorf("transport %v in correct format but not Cloak: %v", transport, err) + err = fmt.Errorf("transport %v in correct format but not Cloak: %v", info.Transport, err) return } - info.Transport = transport if _, ok := sta.ProxyBook[info.ProxyMethod]; !ok { err = ErrBadProxyMethod return diff --git a/internal/server/auth_test.go b/internal/server/auth_test.go index cd5900a..7d131ad 100644 --- a/internal/server/auth_test.go +++ b/internal/server/auth_test.go @@ -98,3 +98,38 @@ func TestTouchStone(t *testing.T) { }) } + +func TestPrepareConnection(t *testing.T) { + nineSixSix := func() time.Time { return time.Unix(1565998966, 0) } + sta, _ := InitState(nineSixSix) + pvBytes, _ := hex.DecodeString("10de5a3c4a4d04efafc3e06d1506363a72bd6d053baef123e6a9a79a0c04b547") + p, _ := ecdh.Unmarshal(pvBytes) + sta.staticPv = p.(crypto.PrivateKey) + sta.ProxyBook["shadowsocks"] = nil + + t.Run("TLS correct", func(t *testing.T) { + chBytes, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000") + info, _, err := PrepareConnection(chBytes, sta, nil) + if err != nil { + t.Errorf("failed to get client info: %v", err) + return + } + if info.SessionId != 3710878841 { + t.Error("failed to get correct session id") + return + } + }) + t.Run("TLS correct but replay", func(t *testing.T) { + chBytes, _ := hex.DecodeString("1603010200010001fc0303ac530b5778469dbbc3f9a83c6ac35b63aa6a70c2014026ade30f2faf0266f0242068424f320bcad49b4315a761f9f6dec32b0a403c2d8c0ab337608a694c6e411c0024130113031302c02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a0100018f00000011000f00000c7777772e62696e672e636f6d00170000ff01000100000a000e000c001d00170018001901000101000b00020100002300000010000e000c02683208687474702f312e310005000501000000000033006b0069001d00204655c2c83aaed1db2e89ed17d671fcdc76dc96e36bde8840022f1bda2f31019600170041543af1f8d28b37d984073f40e8361613da502f16e4039f00656f427de0f66480b2e77e3e552e126bb0cc097168f6e5454c7f9501126a2377fb40151f6cfc007e0e002b0009080304030303020301000d0018001604030503060308040805080604010501060102030201002d00020101001c00024001001500920000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000") + _, _, err := PrepareConnection(chBytes, sta, nil) + if err != nil { + t.Error("failed to prepare for the first time") + return + } + _, _, err = PrepareConnection(chBytes, sta, nil) + if err != ErrReplay { + t.Errorf("failed to return ErrReplay, got %v instead", err) + return + } + }) +} diff --git a/internal/server/transport.go b/internal/server/transport.go index 8b0f46f..baa9872 100644 --- a/internal/server/transport.go +++ b/internal/server/transport.go @@ -1,23 +1,16 @@ package server import ( - "github.com/cbeuw/Cloak/internal/util" + "crypto" + "errors" "net" ) type Transport interface { HasRecordLayer() bool UnitReadFunc() func(net.Conn, []byte) (int, error) + handshake(reqPacket []byte, privateKey crypto.PrivateKey, originalConn net.Conn) (authenticationInfo, func([]byte) (net.Conn, error), error) } -type TLS struct{} - -func (TLS) String() string { return "TLS" } -func (TLS) HasRecordLayer() bool { return true } -func (TLS) UnitReadFunc() func(net.Conn, []byte) (int, error) { return util.ReadTLS } - -type WebSocket struct{} - -func (WebSocket) String() string { return "WebSocket" } -func (WebSocket) HasRecordLayer() bool { return false } -func (WebSocket) UnitReadFunc() func(net.Conn, []byte) (int, error) { return util.ReadWebSocket } +var ErrInvalidPubKey = errors.New("public key has invalid format") +var ErrCiphertextLength = errors.New("ciphertext has the wrong length") diff --git a/internal/server/websocket.go b/internal/server/websocket.go index 61b69c7..b3b832d 100644 --- a/internal/server/websocket.go +++ b/internal/server/websocket.go @@ -1,142 +1,69 @@ package server import ( + "bufio" + "bytes" "crypto" + "crypto/rand" + "encoding/base64" "errors" "fmt" "github.com/cbeuw/Cloak/internal/ecdh" "github.com/cbeuw/Cloak/internal/util" - "github.com/gorilla/websocket" "net" "net/http" - - log "github.com/sirupsen/logrus" ) -// The code in this file is mostly to obtain a binary-oriented, net.Conn analogous -// util.WebSocketConn from the awkward APIs of gorilla/websocket and net/http -// -// The flow of our process is: accept a Conn from remote, read the first packet remote sent us. If it's in the format -// of a TLS handshake, we hand it over to the TLS part; if it's in the format of a HTTP request, we process it as a -// websocket and eventually wrap the remote Conn as util.WebSocketConn, -// -// To get a util.WebSocketConn, we need a gorilla/websocket.Conn. This is obtained by using upgrader.Upgrade method -// inside a HTTP request handler function (which is defined by us). The HTTP request handler function is invoked by -// net/http package upon receiving a request from a Conn. -// -// Ideally we want to give net/http the connection we got from remote, then it can read the first packet (which should -// be an HTTP request) from that Conn and call the handler function, which can then be upgraded to obtain a -// gorilla/websocket.Conn. But this won't work for two reasons: one is that we have ALREADY READ the request packet -// from the remote Conn to determine if it's TLS or HTTP. When net/http reads from the Conn, it will not receive that -// request packet. The second reason is that there is no API in net/http that accepts a Conn at all. Instead, the -// closest we can get is http.Serve which takes in a net.Listener and a http.Handler which implements the ServeHTTP -// function. -// -// Recall that net.Listener has a method Accept which blocks until the Listener receives a connection, then -// it returns a net.Conn. net/http calls Listener.Accept repeatedly and creates a new goroutine handling each Conn -// accepted. -// -// So here is what we need to do: we need to create a type WsAcceptor that implements net.Listener interface. -// the first time WsAcceptor.Accept is called, it will return something that implements net.Conn, subsequent calls to -// Accept will return error (so that the caller won't call again) -// -// The "something that implements net.Conn" needs to do the following: the first time Read is called, it returns the -// request packet we got from the remote Conn which we have already read, so that the packet, which is an HTTP request -// will be processed by the handling function. Subsequent calls to Read will read directly from the remote Conn. To do -// this we create a type firstBuffedConn that implements net.Conn. When we instantiate a firstBuffedConn object, we -// give it the request packet we have already read from the remote Conn, as well as the reference to the remote Conn. -// -// So now we call http.Serve(WsAcceptor, [some handler]), net/http will call WsAcceptor.Accept, which returns a -// firstBuffedConn. net/http will call WsAcceptor.Accept again but this time it returns error so net/http will stop. -// firstBuffedConn.Read will then be called, which returns the request packet from remote Conn. Then -// [some handler].ServeHTTP will be called, in which websocket.upgrader.Upgrade will be called to obtain a -// websocket.Conn -// -// One problem remains: websocket.upgrader.Upgrade is called inside the handling function. The websocket.Conn it -// returned needs to be somehow preserved so we can keep using it. To do this, we define a type WsHandshakeHandler -// which implements http.Handler. WsHandshakeHandler has a struct field of type net.Conn that can be set. Inside -// WsHandshakeHandler.ServeHTTP, the returned websocket.Conn from upgrader.Upgrade will be converted into a -// util.WebSocketConn, whose reference will be kept in the struct field. Whoever has the reference to the instance of -// WsHandshakeHandler can get the reference to the established util.WebSocketConn. -// -// There is another problem: the call of http.Serve(WsAcceptor, WsHandshakeHandler) is async. We don't know when -// the instance of WsHandshakeHandler will have the util.WebSocketConn ready. We synchronise this using a channel. -// A channel called finished will be provided to an instance of WsHandshakeHandler upon its creation. Once -// WsHandshakeHandler.ServeHTTP has the reference to util.WebSocketConn ready, it will write to finished. -// Outside, immediately after the call to http.Serve(WsAcceptor, WsHandshakeHandler), we read from finished so that the -// execution will block until the reference to util.WebSocketConn is ready. +type WebSocket struct{} -// since we need to read the first packet from the client to identify its protocol, the first packet will no longer -// be in Conn's buffer. However, websocket.Upgrade relies on reading the first packet for handshake, so we must -// fake a conn that returns the first packet on first read -type firstBuffedConn struct { - net.Conn - firstRead bool - firstPacket []byte -} +func (WebSocket) String() string { return "WebSocket" } +func (WebSocket) HasRecordLayer() bool { return false } +func (WebSocket) UnitReadFunc() func(net.Conn, []byte) (int, error) { return util.ReadWebSocket } -func (c *firstBuffedConn) Read(buf []byte) (int, error) { - if !c.firstRead { - c.firstRead = true - copy(buf, c.firstPacket) - n := len(c.firstPacket) - c.firstPacket = []byte{} - return n, nil - } - return c.Conn.Read(buf) -} - -type wsAcceptor struct { - done bool - c *firstBuffedConn -} - -// net/http provides no method to serve an existing connection, we must feed in a net.Accept interface to get an -// http.Server. This is an acceptor that accepts only one Conn -func newWsAcceptor(conn net.Conn, first []byte) *wsAcceptor { - f := make([]byte, len(first)) - copy(f, first) - return &wsAcceptor{ - c: &firstBuffedConn{Conn: conn, firstPacket: f}, - } -} - -func (w *wsAcceptor) Accept() (net.Conn, error) { - if w.done { - return nil, errors.New("already accepted") - } - w.done = true - return w.c, nil -} - -func (w *wsAcceptor) Close() error { - w.done = true - return nil -} - -func (w *wsAcceptor) Addr() net.Addr { - return w.c.LocalAddr() -} - -type wsHandshakeHandler struct { - conn net.Conn - finished chan struct{} -} - -// the handler to turn a net.Conn into a websocket.Conn -func newWsHandshakeHandler() *wsHandshakeHandler { - return &wsHandshakeHandler{finished: make(chan struct{})} -} - -func (ws *wsHandshakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - upgrader := websocket.Upgrader{} - c, err := upgrader.Upgrade(w, r, nil) +func (WebSocket) handshake(reqPacket []byte, privateKey crypto.PrivateKey, originalConn net.Conn) (ai authenticationInfo, finisher func([]byte) (net.Conn, error), err error) { + var req *http.Request + req, err = http.ReadRequest(bufio.NewReader(bytes.NewBuffer(reqPacket))) if err != nil { - log.Errorf("failed to upgrade connection to ws: %v", err) + err = fmt.Errorf("failed to parse first HTTP GET: %v", err) return } - ws.conn = &util.WebSocketConn{Conn: c} - ws.finished <- struct{}{} + var hiddenData []byte + hiddenData, err = base64.StdEncoding.DecodeString(req.Header.Get("hidden")) + + ai, err = unmarshalHidden(hiddenData, privateKey) + if err != nil { + err = fmt.Errorf("failed to unmarshal hidden data from WS into authenticationInfo: %v", err) + return + } + + finisher = func(sessionKey []byte) (preparedConn net.Conn, err error) { + handler := newWsHandshakeHandler() + + // For an explanation of the following 3 lines, see the comments in websocketAux.go + http.Serve(newWsAcceptor(originalConn, reqPacket), handler) + + <-handler.finished + preparedConn = handler.conn + nonce := make([]byte, 12) + rand.Read(nonce) + + // reply: [12 bytes nonce][32 bytes encrypted session key][16 bytes authentication tag] + encryptedKey, err := util.AESGCMEncrypt(nonce, ai.sharedSecret, sessionKey) // 32 + 16 = 48 bytes + if err != nil { + err = fmt.Errorf("failed to encrypt reply: %v", err) + return + } + reply := append(nonce, encryptedKey...) + _, err = preparedConn.Write(reply) + if err != nil { + err = fmt.Errorf("failed to write reply: %v", err) + go preparedConn.Close() + return + } + return + } + + return } var ErrBadGET = errors.New("non (or malformed) HTTP GET") diff --git a/internal/server/websocketAux.go b/internal/server/websocketAux.go new file mode 100644 index 0000000..5560c4e --- /dev/null +++ b/internal/server/websocketAux.go @@ -0,0 +1,137 @@ +package server + +import ( + "errors" + "github.com/cbeuw/Cloak/internal/util" + "github.com/gorilla/websocket" + "net" + "net/http" + + log "github.com/sirupsen/logrus" +) + +// The code in this file is mostly to obtain a binary-oriented, net.Conn analogous +// util.WebSocketConn from the awkward APIs of gorilla/websocket and net/http +// +// The flow of our process is: accept a Conn from remote, read the first packet remote sent us. If it's in the format +// of a TLS handshake, we hand it over to the TLS part; if it's in the format of a HTTP request, we process it as a +// websocket and eventually wrap the remote Conn as util.WebSocketConn, +// +// To get a util.WebSocketConn, we need a gorilla/websocket.Conn. This is obtained by using upgrader.Upgrade method +// inside a HTTP request handler function (which is defined by us). The HTTP request handler function is invoked by +// net/http package upon receiving a request from a Conn. +// +// Ideally we want to give net/http the connection we got from remote, then it can read the first packet (which should +// be an HTTP request) from that Conn and call the handler function, which can then be upgraded to obtain a +// gorilla/websocket.Conn. But this won't work for two reasons: one is that we have ALREADY READ the request packet +// from the remote Conn to determine if it's TLS or HTTP. When net/http reads from the Conn, it will not receive that +// request packet. The second reason is that there is no API in net/http that accepts a Conn at all. Instead, the +// closest we can get is http.Serve which takes in a net.Listener and a http.Handler which implements the ServeHTTP +// function. +// +// Recall that net.Listener has a method Accept which blocks until the Listener receives a connection, then +// it returns a net.Conn. net/http calls Listener.Accept repeatedly and creates a new goroutine handling each Conn +// accepted. +// +// So here is what we need to do: we need to create a type WsAcceptor that implements net.Listener interface. +// the first time WsAcceptor.Accept is called, it will return something that implements net.Conn, subsequent calls to +// Accept will return error (so that the caller won't call again) +// +// The "something that implements net.Conn" needs to do the following: the first time Read is called, it returns the +// request packet we got from the remote Conn which we have already read, so that the packet, which is an HTTP request +// will be processed by the handling function. Subsequent calls to Read will read directly from the remote Conn. To do +// this we create a type firstBuffedConn that implements net.Conn. When we instantiate a firstBuffedConn object, we +// give it the request packet we have already read from the remote Conn, as well as the reference to the remote Conn. +// +// So now we call http.Serve(WsAcceptor, [some handler]), net/http will call WsAcceptor.Accept, which returns a +// firstBuffedConn. net/http will call WsAcceptor.Accept again but this time it returns error so net/http will stop. +// firstBuffedConn.Read will then be called, which returns the request packet from remote Conn. Then +// [some handler].ServeHTTP will be called, in which websocket.upgrader.Upgrade will be called to obtain a +// websocket.Conn +// +// One problem remains: websocket.upgrader.Upgrade is called inside the handling function. The websocket.Conn it +// returned needs to be somehow preserved so we can keep using it. To do this, we define a type WsHandshakeHandler +// which implements http.Handler. WsHandshakeHandler has a struct field of type net.Conn that can be set. Inside +// WsHandshakeHandler.ServeHTTP, the returned websocket.Conn from upgrader.Upgrade will be converted into a +// util.WebSocketConn, whose reference will be kept in the struct field. Whoever has the reference to the instance of +// WsHandshakeHandler can get the reference to the established util.WebSocketConn. +// +// There is another problem: the call of http.Serve(WsAcceptor, WsHandshakeHandler) is async. We don't know when +// the instance of WsHandshakeHandler will have the util.WebSocketConn ready. We synchronise this using a channel. +// A channel called finished will be provided to an instance of WsHandshakeHandler upon its creation. Once +// WsHandshakeHandler.ServeHTTP has the reference to util.WebSocketConn ready, it will write to finished. +// Outside, immediately after the call to http.Serve(WsAcceptor, WsHandshakeHandler), we read from finished so that the +// execution will block until the reference to util.WebSocketConn is ready. + +// since we need to read the first packet from the client to identify its protocol, the first packet will no longer +// be in Conn's buffer. However, websocket.Upgrade relies on reading the first packet for handshake, so we must +// fake a conn that returns the first packet on first read +type firstBuffedConn struct { + net.Conn + firstRead bool + firstPacket []byte +} + +func (c *firstBuffedConn) Read(buf []byte) (int, error) { + if !c.firstRead { + c.firstRead = true + copy(buf, c.firstPacket) + n := len(c.firstPacket) + c.firstPacket = []byte{} + return n, nil + } + return c.Conn.Read(buf) +} + +type wsAcceptor struct { + done bool + c *firstBuffedConn +} + +// net/http provides no method to serve an existing connection, we must feed in a net.Accept interface to get an +// http.Server. This is an acceptor that accepts only one Conn +func newWsAcceptor(conn net.Conn, first []byte) *wsAcceptor { + f := make([]byte, len(first)) + copy(f, first) + return &wsAcceptor{ + c: &firstBuffedConn{Conn: conn, firstPacket: f}, + } +} + +func (w *wsAcceptor) Accept() (net.Conn, error) { + if w.done { + return nil, errors.New("already accepted") + } + w.done = true + return w.c, nil +} + +func (w *wsAcceptor) Close() error { + w.done = true + return nil +} + +func (w *wsAcceptor) Addr() net.Addr { + return w.c.LocalAddr() +} + +type wsHandshakeHandler struct { + conn net.Conn + finished chan struct{} +} + +// the handler to turn a net.Conn into a websocket.Conn +func newWsHandshakeHandler() *wsHandshakeHandler { + return &wsHandshakeHandler{finished: make(chan struct{})} +} + +func (ws *wsHandshakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + upgrader := websocket.Upgrader{} + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Errorf("failed to upgrade connection to ws: %v", err) + return + } + ws.conn = &util.WebSocketConn{Conn: c} + ws.finished <- struct{}{} +}