Add Stream Timeout

This commit is contained in:
Andy Wang 2019-08-19 23:23:41 +01:00
parent ba467e8a32
commit eabe113547
8 changed files with 104 additions and 82 deletions

View File

@ -97,82 +97,86 @@ func routeUDP(sta *client.State, adminUID []byte) {
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
start:
localConn, err := net.ListenUDP("udp", localUDPAddr) localConn, err := net.ListenUDP("udp", localUDPAddr)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
for { var otherEnd atomic.Value
var otherEnd atomic.Value data := make([]byte, 10240)
data := make([]byte, 10240) i, oe, err := localConn.ReadFromUDP(data)
i, oe, err := localConn.ReadFromUDP(data) if err != nil {
if err != nil { log.Errorf("Failed to read first packet from proxy client: %v", err)
log.Errorf("Failed to read first packet from proxy client: %v", err) localConn.Close()
localConn.Close() return
return }
} otherEnd.Store(oe)
otherEnd.Store(oe)
if sesh == nil || sesh.IsClosed() { if sesh == nil || sesh.IsClosed() {
sesh = makeSession(sta, adminUID != nil, true) sesh = makeSession(sta, adminUID != nil, true)
} }
log.Debugf("proxy local address %v", otherEnd.Load().(*net.UDPAddr).String()) log.Debugf("proxy local address %v", otherEnd.Load().(*net.UDPAddr).String())
stream, err := sesh.OpenStream() stream, err := sesh.OpenStream()
if err != nil { if err != nil {
log.Errorf("Failed to open stream: %v", err) log.Errorf("Failed to open stream: %v", err)
localConn.Close() localConn.Close()
//localConnWrite.Close() //localConnWrite.Close()
return return
} }
_, err = stream.Write(data[:i]) _, err = stream.Write(data[:i])
if err != nil { if err != nil {
log.Errorf("Failed to write to stream: %v", err) log.Errorf("Failed to write to stream: %v", err)
localConn.Close() localConn.Close()
//localConnWrite.Close() //localConnWrite.Close()
stream.Close() stream.Close()
return return
} }
// stream to proxy // stream to proxy
go func() { go func() {
buf := make([]byte, 16380)
for {
i, err := io.ReadAtLeast(stream, buf, 1)
if err != nil {
log.Print(err)
go localConn.Close()
go stream.Close()
return
}
i, err = localConn.WriteToUDP(buf[:i], otherEnd.Load().(*net.UDPAddr))
if err != nil {
log.Print(err)
go localConn.Close()
go stream.Close()
return
}
}
}()
// proxy to stream
buf := make([]byte, 16380) buf := make([]byte, 16380)
for { for {
i, oe, err := localConn.ReadFromUDP(buf) i, err := io.ReadAtLeast(stream, buf, 1)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
go localConn.Close() localConn.Close()
go stream.Close() stream.Close()
return break
} }
otherEnd.Store(oe) i, err = localConn.WriteToUDP(buf[:i], otherEnd.Load().(*net.UDPAddr))
i, err = stream.Write(buf[:i])
if err != nil { if err != nil {
log.Print(err) log.Print(err)
go localConn.Close() localConn.Close()
go stream.Close() stream.Close()
return break
} }
} }
}()
// proxy to stream
buf := make([]byte, 16380)
if sta.Timeout != 0 {
localConn.SetReadDeadline(time.Now().Add(sta.Timeout))
} }
for {
if sta.Timeout != 0 {
localConn.SetReadDeadline(time.Now().Add(sta.Timeout))
}
i, oe, err := localConn.ReadFromUDP(buf)
if err != nil {
localConn.Close()
stream.Close()
break
}
otherEnd.Store(oe)
i, err = stream.Write(buf[:i])
if err != nil {
localConn.Close()
stream.Close()
break
}
}
goto start
} }
@ -212,8 +216,8 @@ func routeTCP(sta *client.State, adminUID []byte) {
stream.Close() stream.Close()
return return
} }
go util.Pipe(localConn, stream) go util.Pipe(localConn, stream, 0)
util.Pipe(stream, localConn) util.Pipe(stream, localConn, sta.Timeout)
}() }()
} }

View File

@ -48,8 +48,8 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
if err != nil { if err != nil {
log.Error("Failed to send first packet to redirection server", err) log.Error("Failed to send first packet to redirection server", err)
} }
go util.Pipe(webConn, conn) go util.Pipe(webConn, conn, 0)
go util.Pipe(conn, webConn) go util.Pipe(conn, webConn, 0)
} }
ci, finishHandshake, err := server.PrepareConnection(data, sta, conn) ci, finishHandshake, err := server.PrepareConnection(data, sta, conn)
@ -177,8 +177,8 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
} }
log.Debugf("%v endpoint has been successfully connected", ci.ProxyMethod) log.Debugf("%v endpoint has been successfully connected", ci.ProxyMethod)
go util.Pipe(localConn, newStream) go util.Pipe(localConn, newStream, 0)
go util.Pipe(newStream, localConn) go util.Pipe(newStream, localConn, sta.Timeout)
} }

View File

@ -5,5 +5,6 @@
"PublicKey":"IYoUzkle/T/kriE+Ufdm7AHQtIeGnBWbhhlTbmDpUUI=", "PublicKey":"IYoUzkle/T/kriE+Ufdm7AHQtIeGnBWbhhlTbmDpUUI=",
"ServerName":"www.bing.com", "ServerName":"www.bing.com",
"NumConn":4, "NumConn":4,
"BrowserSig":"chrome" "BrowserSig": "chrome",
"StreamTimeout": 300
} }

View File

@ -19,5 +19,6 @@
"RedirAddr": "204.79.197.200:443", "RedirAddr": "204.79.197.200:443",
"PrivateKey": "EN5aPEpNBO+vw+BtFQY2OnK9bQU7rvEj5qmnmgwEtUc=", "PrivateKey": "EN5aPEpNBO+vw+BtFQY2OnK9bQU7rvEj5qmnmgwEtUc=",
"AdminUID": "5nneblJy6lniPJfr81LuYQ==", "AdminUID": "5nneblJy6lniPJfr81LuYQ==",
"DatabasePath": "userinfo.db" "DatabasePath": "userinfo.db",
"StreamTimeout": 300
} }

View File

@ -20,6 +20,7 @@ type rawConfig struct {
PublicKey string PublicKey string
BrowserSig string BrowserSig string
NumConn int NumConn int
StreamTimeout int
} }
// State stores global variables // State stores global variables
@ -41,6 +42,7 @@ type State struct {
EncryptionMethod byte EncryptionMethod byte
ServerName string ServerName string
NumConn int NumConn int
Timeout time.Duration
} }
func InitState(localHost, localPort, remoteHost, remotePort string, nowFunc func() time.Time) *State { func InitState(localHost, localPort, remoteHost, remotePort string, nowFunc func() time.Time) *State {
@ -73,7 +75,7 @@ func ssvToJson(ssv string) (ret []byte) {
value := sp[1] value := sp[1]
// JSON doesn't like quotation marks around int // JSON doesn't like quotation marks around int
// Yes this is extremely ugly but it's still better than writing a tokeniser // Yes this is extremely ugly but it's still better than writing a tokeniser
if key == "NumConn" || key == "Unordered" { if key == "NumConn" || key == "Unordered" || key == "StreamTimeout" {
ret = append(ret, []byte(`"`+key+`":`+value+`,`)...) ret = append(ret, []byte(`"`+key+`":`+value+`,`)...)
} else { } else {
ret = append(ret, []byte(`"`+key+`":"`+value+`",`)...) ret = append(ret, []byte(`"`+key+`":"`+value+`",`)...)
@ -124,6 +126,7 @@ func (sta *State) ParseConfig(conf string) (err error) {
sta.ProxyMethod = preParse.ProxyMethod sta.ProxyMethod = preParse.ProxyMethod
sta.ServerName = preParse.ServerName sta.ServerName = preParse.ServerName
sta.NumConn = preParse.NumConn sta.NumConn = preParse.NumConn
sta.Timeout = time.Duration(preParse.StreamTimeout) * time.Second
uid, err := base64.StdEncoding.DecodeString(preParse.UID) uid, err := base64.StdEncoding.DecodeString(preParse.UID)
if err != nil { if err != nil {

View File

@ -187,7 +187,10 @@ func (sesh *Session) TerminalMsg() string {
func (sesh *Session) Close() error { func (sesh *Session) Close() error {
log.Debugf("attempting to close session %v", sesh.id) log.Debugf("attempting to close session %v", sesh.id)
atomic.StoreUint32(&sesh.closed, 1) if atomic.SwapUint32(&sesh.closed, 1) == 1 {
log.Debugf("session %v has already been closed", sesh.id)
return errRepeatSessionClosing
}
sesh.streamsM.Lock() sesh.streamsM.Lock()
sesh.acceptCh <- nil sesh.acceptCh <- nil
for id, stream := range sesh.streams { for id, stream := range sesh.streams {

View File

@ -17,13 +17,14 @@ import (
) )
type rawConfig struct { type rawConfig struct {
ProxyBook map[string][]string ProxyBook map[string][]string
BypassUID [][]byte BypassUID [][]byte
RedirAddr string RedirAddr string
PrivateKey string PrivateKey string
AdminUID string AdminUID string
DatabasePath string DatabasePath string
CncMode bool StreamTimeout int
CncMode bool
} }
// State type stores the global state of the program // State type stores the global state of the program
@ -35,6 +36,7 @@ type State struct {
Now func() time.Time Now func() time.Time
AdminUID []byte AdminUID []byte
Timeout time.Duration
BypassUID map[[16]byte]struct{} BypassUID map[[16]byte]struct{}
staticPv crypto.PrivateKey staticPv crypto.PrivateKey
@ -92,6 +94,7 @@ func (sta *State) ParseConfig(conf string) (err error) {
} }
sta.RedirAddr = preParse.RedirAddr sta.RedirAddr = preParse.RedirAddr
sta.Timeout = time.Duration(preParse.StreamTimeout) * time.Second
for name, pair := range preParse.ProxyBook { for name, pair := range preParse.ProxyBook {
if len(pair) != 2 { if len(pair) != 2 {

View File

@ -8,6 +8,7 @@ import (
"io" "io"
"net" "net"
"strconv" "strconv"
"time"
) )
func AESGCMEncrypt(nonce []byte, key []byte, plaintext []byte) ([]byte, error) { func AESGCMEncrypt(nonce []byte, key []byte, plaintext []byte) ([]byte, error) {
@ -86,22 +87,28 @@ func AddRecordLayer(input []byte, typ []byte, ver []byte) []byte {
return ret return ret
} }
func Pipe(dst io.ReadWriteCloser, src io.ReadWriteCloser) { func Pipe(dst net.Conn, src net.Conn, srcReadTimeout time.Duration) {
// The maximum size of TLS message will be 16380+12+16. 12 because of the stream header and 16 // The maximum size of TLS message will be 16380+12+16. 12 because of the stream header and 16
// because of the salt/mac // because of the salt/mac
// 16408 is the max TLS message size on Firefox // 16408 is the max TLS message size on Firefox
buf := make([]byte, 16380) buf := make([]byte, 16380)
if srcReadTimeout != 0 {
src.SetReadDeadline(time.Now().Add(srcReadTimeout))
}
for { for {
if srcReadTimeout != 0 {
src.SetReadDeadline(time.Now().Add(srcReadTimeout))
}
i, err := io.ReadAtLeast(src, buf, 1) i, err := io.ReadAtLeast(src, buf, 1)
if err != nil { if err != nil {
go dst.Close() dst.Close()
go src.Close() src.Close()
return return
} }
i, err = dst.Write(buf[:i]) i, err = dst.Write(buf[:i])
if err != nil { if err != nil {
go dst.Close() dst.Close()
go src.Close() src.Close()
return return
} }
} }