diff --git a/internal/client/piper.go b/internal/client/piper.go index c867db5..4bb9fc0 100644 --- a/internal/client/piper.go +++ b/internal/client/piper.go @@ -18,7 +18,8 @@ func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration log.Fatal(err) } - var streams sync.Map + streams := make(map[string]*mux.Stream) + var streamsMutex sync.Mutex data := make([]byte, 8192) for { @@ -32,25 +33,27 @@ func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration sesh = newSeshFunc() } - var stream *mux.Stream - streamObj, ok := streams.Load(addr.String()) + streamsMutex.Lock() + stream, ok := streams[addr.String()] if !ok { if singleplex { sesh = newSeshFunc() } stream, err = sesh.OpenStream() - streamObj = stream if err != nil { if singleplex { sesh.Close() } log.Errorf("Failed to open stream: %v", err) + streamsMutex.Unlock() continue } + streams[addr.String()] = stream + streamsMutex.Unlock() + _ = stream.SetReadDeadline(time.Now().Add(streamTimeout)) - streams.Store(addr.String(), stream) proxyAddr := addr go func(stream *mux.Stream, localConn *net.UDPConn) { buf := make([]byte, 8192) @@ -58,28 +61,32 @@ func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration n, err := stream.Read(buf) if err != nil { log.Tracef("copying stream to proxy client: %v", err) - streams.Delete(addr.String()) - stream.Close() - return + break } _ = stream.SetReadDeadline(time.Now().Add(streamTimeout)) _, err = localConn.WriteTo(buf[:n], proxyAddr) if err != nil { log.Tracef("copying stream to proxy client: %v", err) - streams.Delete(addr.String()) - stream.Close() - return + break } } + streamsMutex.Lock() + delete(streams, addr.String()) + streamsMutex.Unlock() + stream.Close() + return }(stream, localConn) + } else { + streamsMutex.Unlock() } - stream = streamObj.(*mux.Stream) _, err = stream.Write(data[:i]) if err != nil { log.Tracef("copying proxy client to stream: %v", err) - streams.Delete(addr.String()) + streamsMutex.Lock() + delete(streams, addr.String()) + streamsMutex.Unlock() stream.Close() continue } diff --git a/internal/common/tls.go b/internal/common/tls.go index 4917e8e..fd2fce4 100644 --- a/internal/common/tls.go +++ b/internal/common/tls.go @@ -1,6 +1,7 @@ package common import ( + "bytes" "encoding/binary" "io" "net" @@ -36,18 +37,15 @@ func AddRecordLayer(input []byte, typ byte, ver uint16) []byte { type TLSConn struct { net.Conn - writeM sync.Mutex - writeBuf []byte + writeBufPool sync.Pool } func NewTLSConn(conn net.Conn) *TLSConn { - writeBuf := make([]byte, initialWriteBufSize) - writeBuf[0] = ApplicationData - writeBuf[1] = byte(VersionTLS13 >> 8) - writeBuf[2] = byte(VersionTLS13 & 0xFF) return &TLSConn{ - Conn: conn, - writeBuf: writeBuf, + Conn: conn, + writeBufPool: sync.Pool{New: func() interface{} { + return new(bytes.Buffer) + }}, } } @@ -95,13 +93,16 @@ func (tls *TLSConn) Read(buffer []byte) (n int, err error) { func (tls *TLSConn) Write(in []byte) (n int, err error) { msgLen := len(in) - tls.writeM.Lock() - tls.writeBuf = append(tls.writeBuf[:5], in...) - tls.writeBuf[3] = byte(msgLen >> 8) - tls.writeBuf[4] = byte(msgLen & 0xFF) - n, err = tls.Conn.Write(tls.writeBuf[:recordLayerLength+msgLen]) - tls.writeM.Unlock() - return n - recordLayerLength, err + writeBuf := tls.writeBufPool.Get().(*bytes.Buffer) + writeBuf.WriteByte(ApplicationData) + writeBuf.WriteByte(byte(VersionTLS13 >> 8)) + writeBuf.WriteByte(byte(VersionTLS13 & 0xFF)) + writeBuf.WriteByte(byte(msgLen >> 8)) + writeBuf.WriteByte(byte(msgLen & 0xFF)) + writeBuf.Write(in) + i, err := writeBuf.WriteTo(tls.Conn) + tls.writeBufPool.Put(writeBuf) + return int(i - recordLayerLength), err } func (tls *TLSConn) Close() error {