diff --git a/internal/common/tls.go b/internal/common/tls.go index 9af2c3c..30af801 100644 --- a/internal/common/tls.go +++ b/internal/common/tls.go @@ -16,25 +16,27 @@ const ( Handshake = 22 ApplicationData = 23 + + initialWriteBufSize = 14336 ) -func AddRecordLayer(input []byte, typ byte, ver uint16) []byte { - msgLen := len(input) - retLen := msgLen + recordLayerLength - var ret []byte - if cap(input) >= retLen { - ret = input[:retLen] - } else { - ret = make([]byte, retLen) - } - copy(ret[recordLayerLength:], input) - ret[0] = typ - ret[1] = byte(ver >> 8) - ret[2] = byte(ver) - ret[3] = byte(msgLen >> 8) - ret[4] = byte(msgLen) - return ret -} +//func AddRecordLayer(input []byte, typ byte, ver uint16) []byte { +// msgLen := len(input) +// retLen := msgLen + recordLayerLength +// var ret []byte +// if cap(input) >= retLen { +// ret = input[:retLen] +// } else { +// ret = make([]byte, retLen) +// } +// copy(ret[recordLayerLength:], input) +// ret[0] = typ +// ret[1] = byte(ver >> 8) +// ret[2] = byte(ver) +// ret[3] = byte(msgLen >> 8) +// ret[4] = byte(msgLen) +// return ret +//} type TLSConn struct { net.Conn @@ -43,9 +45,13 @@ type TLSConn struct { } func NewTLSConn(conn net.Conn) *TLSConn { + writeBuf := make([]byte, initialWriteBufSize) + writeBuf[0] = ApplicationData + writeBuf[1] = byte(VersionTLS13 >> 8) + writeBuf[2] = byte(VersionTLS13 & 0xFF) return &TLSConn{ Conn: conn, - writeBuf: make([]byte, 15000), + writeBuf: writeBuf, } } @@ -74,6 +80,9 @@ func (tls *TLSConn) Read(buffer []byte) (n int, err error) { // a single message can also be segmented due to MTU of the IP layer. // This function guareentees a single TLS message to be read and everything // else is left in the buffer. + if len(buffer) < recordLayerLength { + return 0, io.ErrShortBuffer + } _, err = io.ReadFull(tls.Conn, buffer[:recordLayerLength]) if err != nil { return @@ -92,9 +101,6 @@ func (tls *TLSConn) Write(in []byte) (n int, err error) { msgLen := len(in) tls.writeM.Lock() tls.writeBuf = append(tls.writeBuf[:5], in...) - tls.writeBuf[0] = ApplicationData - tls.writeBuf[1] = byte(VersionTLS13 >> 8) - tls.writeBuf[2] = byte(VersionTLS13 & 0xFF) tls.writeBuf[3] = byte(msgLen >> 8) tls.writeBuf[4] = byte(msgLen & 0xFF) n, err = tls.Conn.Write(tls.writeBuf[:recordLayerLength+msgLen]) diff --git a/internal/common/tls_test.go b/internal/common/tls_test.go new file mode 100644 index 0000000..3f8ce93 --- /dev/null +++ b/internal/common/tls_test.go @@ -0,0 +1,40 @@ +package common + +import ( + "net" + "testing" +) + +func BenchmarkTLSConn_Write(b *testing.B) { + const bufSize = 16 * 1024 + addrCh := make(chan string, 1) + go func() { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + b.Fatal(err) + } + addrCh <- listener.Addr().String() + conn, err := listener.Accept() + if err != nil { + b.Fatal(err) + } + readBuf := make([]byte, bufSize*2) + for { + _, err = conn.Read(readBuf) + if err != nil { + return + } + } + }() + data := make([]byte, bufSize) + discardConn, _ := net.Dial("tcp", <-addrCh) + tlsConn := NewTLSConn(discardConn) + defer tlsConn.Close() + b.SetBytes(bufSize) + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + tlsConn.Write(data) + } + }) +}