Merge branch 'master' into notsure2

This commit is contained in:
notsure2 2020-12-23 02:51:21 +02:00
commit 03173c71e5
2 changed files with 36 additions and 28 deletions

View File

@ -18,7 +18,8 @@ func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration
log.Fatal(err) log.Fatal(err)
} }
var streams sync.Map streams := make(map[string]*mux.Stream)
var streamsMutex sync.Mutex
data := make([]byte, 8192) data := make([]byte, 8192)
for { for {
@ -32,25 +33,27 @@ func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration
sesh = newSeshFunc() sesh = newSeshFunc()
} }
var stream *mux.Stream streamsMutex.Lock()
streamObj, ok := streams.Load(addr.String()) stream, ok := streams[addr.String()]
if !ok { if !ok {
if singleplex { if singleplex {
sesh = newSeshFunc() sesh = newSeshFunc()
} }
stream, err = sesh.OpenStream() stream, err = sesh.OpenStream()
streamObj = stream
if err != nil { if err != nil {
if singleplex { if singleplex {
sesh.Close() sesh.Close()
} }
log.Errorf("Failed to open stream: %v", err) log.Errorf("Failed to open stream: %v", err)
streamsMutex.Unlock()
continue continue
} }
streams[addr.String()] = stream
streamsMutex.Unlock()
_ = stream.SetReadDeadline(time.Now().Add(streamTimeout)) _ = stream.SetReadDeadline(time.Now().Add(streamTimeout))
streams.Store(addr.String(), stream)
proxyAddr := addr proxyAddr := addr
go func(stream *mux.Stream, localConn *net.UDPConn) { go func(stream *mux.Stream, localConn *net.UDPConn) {
buf := make([]byte, 8192) buf := make([]byte, 8192)
@ -58,28 +61,32 @@ func RouteUDP(bindFunc func() (*net.UDPConn, error), streamTimeout time.Duration
n, err := stream.Read(buf) n, err := stream.Read(buf)
if err != nil { if err != nil {
log.Tracef("copying stream to proxy client: %v", err) log.Tracef("copying stream to proxy client: %v", err)
streams.Delete(addr.String()) break
stream.Close()
return
} }
_ = stream.SetReadDeadline(time.Now().Add(streamTimeout)) _ = stream.SetReadDeadline(time.Now().Add(streamTimeout))
_, err = localConn.WriteTo(buf[:n], proxyAddr) _, err = localConn.WriteTo(buf[:n], proxyAddr)
if err != nil { if err != nil {
log.Tracef("copying stream to proxy client: %v", err) log.Tracef("copying stream to proxy client: %v", err)
streams.Delete(addr.String()) break
}
}
streamsMutex.Lock()
delete(streams, addr.String())
streamsMutex.Unlock()
stream.Close() stream.Close()
return return
}
}
}(stream, localConn) }(stream, localConn)
} else {
streamsMutex.Unlock()
} }
stream = streamObj.(*mux.Stream)
_, err = stream.Write(data[:i]) _, err = stream.Write(data[:i])
if err != nil { if err != nil {
log.Tracef("copying proxy client to stream: %v", err) log.Tracef("copying proxy client to stream: %v", err)
streams.Delete(addr.String()) streamsMutex.Lock()
delete(streams, addr.String())
streamsMutex.Unlock()
stream.Close() stream.Close()
continue continue
} }

View File

@ -1,6 +1,7 @@
package common package common
import ( import (
"bytes"
"encoding/binary" "encoding/binary"
"io" "io"
"net" "net"
@ -36,18 +37,15 @@ func AddRecordLayer(input []byte, typ byte, ver uint16) []byte {
type TLSConn struct { type TLSConn struct {
net.Conn net.Conn
writeM sync.Mutex writeBufPool sync.Pool
writeBuf []byte
} }
func NewTLSConn(conn net.Conn) *TLSConn { func NewTLSConn(conn net.Conn) *TLSConn {
writeBuf := make([]byte, initialWriteBufSize)
writeBuf[0] = ApplicationData
writeBuf[1] = byte(VersionTLS13 >> 8)
writeBuf[2] = byte(VersionTLS13 & 0xFF)
return &TLSConn{ return &TLSConn{
Conn: conn, Conn: conn,
writeBuf: writeBuf, writeBufPool: sync.Pool{New: func() interface{} {
return new(bytes.Buffer)
}},
} }
} }
@ -95,13 +93,16 @@ func (tls *TLSConn) Read(buffer []byte) (n int, err error) {
func (tls *TLSConn) Write(in []byte) (n int, err error) { func (tls *TLSConn) Write(in []byte) (n int, err error) {
msgLen := len(in) msgLen := len(in)
tls.writeM.Lock() writeBuf := tls.writeBufPool.Get().(*bytes.Buffer)
tls.writeBuf = append(tls.writeBuf[:5], in...) writeBuf.WriteByte(ApplicationData)
tls.writeBuf[3] = byte(msgLen >> 8) writeBuf.WriteByte(byte(VersionTLS13 >> 8))
tls.writeBuf[4] = byte(msgLen & 0xFF) writeBuf.WriteByte(byte(VersionTLS13 & 0xFF))
n, err = tls.Conn.Write(tls.writeBuf[:recordLayerLength+msgLen]) writeBuf.WriteByte(byte(msgLen >> 8))
tls.writeM.Unlock() writeBuf.WriteByte(byte(msgLen & 0xFF))
return n - recordLayerLength, err writeBuf.Write(in)
i, err := writeBuf.WriteTo(tls.Conn)
tls.writeBufPool.Put(writeBuf)
return int(i - recordLayerLength), err
} }
func (tls *TLSConn) Close() error { func (tls *TLSConn) Close() error {