diff --git a/internal/common/tls.go b/internal/common/tls.go index fb54e97..3953992 100644 --- a/internal/common/tls.go +++ b/internal/common/tls.go @@ -2,6 +2,7 @@ package common import ( "encoding/binary" + "errors" "io" "net" "sync" @@ -94,6 +95,9 @@ func (tls *TLSConn) Read(buffer []byte) (n int, err error) { func (tls *TLSConn) Write(in []byte) (n int, err error) { msgLen := len(in) + if msgLen > 1<<14+256 { // https://tools.ietf.org/html/rfc8446#section-5.2 + return 0, errors.New("message is too long") + } writeBuf := tls.writeBufPool.Get().(*[]byte) *writeBuf = append(*writeBuf, byte(msgLen>>8), byte(msgLen&0xFF)) *writeBuf = append(*writeBuf, in...) diff --git a/internal/multiplex/obfs_test.go b/internal/multiplex/obfs_test.go index 2dad728..76ba3bc 100644 --- a/internal/multiplex/obfs_test.go +++ b/internal/multiplex/obfs_test.go @@ -138,7 +138,7 @@ func BenchmarkObfs(b *testing.B) { testPayload, } - obfsBuf := make([]byte, defaultSendRecvBufSize) + obfsBuf := make([]byte, len(testPayload)*2) var key [32]byte rand.Read(key[:]) @@ -211,7 +211,7 @@ func BenchmarkDeobfs(b *testing.B) { testPayload, } - obfsBuf := make([]byte, defaultSendRecvBufSize) + obfsBuf := make([]byte, len(testPayload)*2) var key [32]byte rand.Read(key[:]) diff --git a/internal/multiplex/recvBuffer.go b/internal/multiplex/recvBuffer.go index 63f1f6f..a8aed87 100644 --- a/internal/multiplex/recvBuffer.go +++ b/internal/multiplex/recvBuffer.go @@ -25,4 +25,4 @@ type recvBuffer interface { // size we want the amount of unread data in buffer to grow before recvBuffer.Write blocks. // If the buffer grows larger than what the system's memory can offer at the time of recvBuffer.Write, // a panic will happen. -const recvBufferSizeLimit = defaultSendRecvBufSize << 12 +const recvBufferSizeLimit = 1 << 31 diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index 6abc90e..b9d8540 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -13,10 +13,9 @@ import ( ) const ( - acceptBacklog = 1024 - // TODO: will this be a signature? - defaultSendRecvBufSize = 20480 + acceptBacklog = 1024 defaultInactivityTimeout = 30 * time.Second + defaultMaxOnWireSize = 1<<14 + 256 // https://tools.ietf.org/html/rfc8446#section-5.2 ) var ErrBrokenSession = errors.New("broken session") @@ -40,12 +39,6 @@ type SessionConfig struct { // maximum size of an obfuscated frame, including headers and overhead MsgOnWireSizeLimit int - // StreamSendBufferSize sets the buffer size used to send data from a Stream (Stream.obfsBuf) - StreamSendBufferSize int - // ConnReceiveBufferSize sets the buffer size used to receive data from an underlying Conn (allocated in - // switchboard.deplex) - ConnReceiveBufferSize int - // InactivityTimeout sets the duration a Session waits while it has no active streams before it closes itself InactivityTimeout time.Duration } @@ -87,6 +80,11 @@ type Session struct { // the max size passed to Write calls before it splits it into multiple frames // i.e. the max size a piece of data can fit into a Frame.Payload maxStreamUnitWrite int + // streamSendBufferSize sets the buffer size used to send data from a Stream (Stream.obfsBuf) + streamSendBufferSize int + // connReceiveBufferSize sets the buffer size used to receive data from an underlying Conn (allocated in + // switchboard.deplex) + connReceiveBufferSize int } func MakeSession(id uint32, config SessionConfig) *Session { @@ -103,23 +101,19 @@ func MakeSession(id uint32, config SessionConfig) *Session { if config.Valve == nil { sesh.Valve = UNLIMITED_VALVE } - if config.StreamSendBufferSize <= 0 { - sesh.StreamSendBufferSize = defaultSendRecvBufSize - } - if config.ConnReceiveBufferSize <= 0 { - sesh.ConnReceiveBufferSize = defaultSendRecvBufSize - } if config.MsgOnWireSizeLimit <= 0 { - sesh.MsgOnWireSizeLimit = defaultSendRecvBufSize - 1024 + sesh.MsgOnWireSizeLimit = defaultMaxOnWireSize } if config.InactivityTimeout == 0 { sesh.InactivityTimeout = defaultInactivityTimeout } - // todo: validation. this must be smaller than StreamSendBufferSize - sesh.maxStreamUnitWrite = sesh.MsgOnWireSizeLimit - frameHeaderLength - sesh.Obfuscator.maxOverhead + + sesh.maxStreamUnitWrite = sesh.MsgOnWireSizeLimit - frameHeaderLength - sesh.maxOverhead + sesh.streamSendBufferSize = sesh.MsgOnWireSizeLimit + sesh.connReceiveBufferSize = 20480 // for backwards compatibility sesh.streamObfsBufPool = sync.Pool{New: func() interface{} { - b := make([]byte, sesh.StreamSendBufferSize) + b := make([]byte, sesh.streamSendBufferSize) return &b }} diff --git a/internal/multiplex/switchboard.go b/internal/multiplex/switchboard.go index 84e43c9..829d944 100644 --- a/internal/multiplex/switchboard.go +++ b/internal/multiplex/switchboard.go @@ -159,7 +159,7 @@ func (sb *switchboard) closeAll() { // deplex function costantly reads from a TCP connection func (sb *switchboard) deplex(connId uint32, conn net.Conn) { defer conn.Close() - buf := make([]byte, sb.session.ConnReceiveBufferSize) + buf := make([]byte, sb.session.connReceiveBufferSize) for { n, err := conn.Read(buf) sb.valve.rxWait(n) diff --git a/internal/test/integration_test.go b/internal/test/integration_test.go index e4072e0..187507f 100644 --- a/internal/test/integration_test.go +++ b/internal/test/integration_test.go @@ -321,7 +321,7 @@ func TestTCPSingleplex(t *testing.T) { t.Fatal(err) } - const echoMsgLen = 16384 + const echoMsgLen = 1 << 16 go serveTCPEcho(proxyFromCkServerL) proxyConn1, err := proxyToCkClientD.Dial("", "")