diff --git a/internal/server/TLS.go b/internal/server/TLS.go index d00f466..10b408d 100644 --- a/internal/server/TLS.go +++ b/internal/server/TLS.go @@ -104,6 +104,10 @@ func parseClientHello(data []byte) (ret *ClientHello, err error) { } }() + if !bytes.Equal(data[0:3], []byte{0x16, 0x03, 0x01}) { + return ret, errors.New("wrong TLS handshake magic bytes") + } + peeled := make([]byte, len(data)-5) copy(peeled, data[5:]) pointer := 0 diff --git a/internal/server/TLS_test.go b/internal/server/TLS_test.go index ef70265..a58ccd9 100644 --- a/internal/server/TLS_test.go +++ b/internal/server/TLS_test.go @@ -12,9 +12,11 @@ func TestParseClientHello(t *testing.T) { ch, err := parseClientHello(chBytes) if err != nil { t.Errorf("Expecting no error, got %v", err) + return } if !bytes.Equal(ch.clientVersion, []byte{0x03, 0x03}) { t.Errorf("expecting client version 0x0303, got %v", ch.clientVersion) + return } }) t.Run("Malformed ClientHello", func(t *testing.T) { @@ -22,6 +24,7 @@ func TestParseClientHello(t *testing.T) { _, err := parseClientHello(chBytes) if err == nil { t.Error("expecting Malformed ClientHello, got no error") + return } }) t.Run("not Handshake", func(t *testing.T) { @@ -29,6 +32,7 @@ func TestParseClientHello(t *testing.T) { _, err := parseClientHello(chBytes) if err == nil { t.Error("not a tls handshake, got no error") + return } }) t.Run("wrong version", func(t *testing.T) { @@ -36,6 +40,7 @@ func TestParseClientHello(t *testing.T) { _, err := parseClientHello(chBytes) if err == nil { t.Error("wrong version, got no error") + return } }) }