Make keepalive optional on client -> server and server -> proxy connections. Use KeepAlive value in config (seconds).

This commit is contained in:
notsure2 2019-12-09 18:12:47 +02:00 committed by Andy Wang
parent 6b973045d5
commit 2de034ec92
4 changed files with 22 additions and 6 deletions

View File

@ -23,7 +23,7 @@ import (
var version string var version string
func makeSession(sta *client.State, isAdmin bool) *mux.Session { func makeSession(sta *client.State, isAdmin bool) *mux.Session {
log.Info("Attemtping to start a new session") log.Info("Attempting to start a new session")
if !isAdmin { if !isAdmin {
// sessionID is usergenerated. There shouldn't be a security concern because the scope of // sessionID is usergenerated. There shouldn't be a security concern because the scope of
// sessionID is limited to its UID. // sessionID is limited to its UID.
@ -32,7 +32,7 @@ func makeSession(sta *client.State, isAdmin bool) *mux.Session {
atomic.StoreUint32(&sta.SessionID, binary.BigEndian.Uint32(quad)) atomic.StoreUint32(&sta.SessionID, binary.BigEndian.Uint32(quad))
} }
d := net.Dialer{Control: protector} d := net.Dialer{Control: protector, KeepAlive: sta.KeepAlive}
connsCh := make(chan net.Conn, sta.NumConn) connsCh := make(chan net.Conn, sta.NumConn)
var _sessionKey atomic.Value var _sessionKey atomic.Value
var wg sync.WaitGroup var wg sync.WaitGroup

View File

@ -174,7 +174,8 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
} }
} }
proxyAddr := sta.ProxyBook[ci.ProxyMethod] proxyAddr := sta.ProxyBook[ci.ProxyMethod]
localConn, err := net.Dial(proxyAddr.Network(), proxyAddr.String()) d := net.Dialer{KeepAlive: sta.KeepAlive}
localConn, err := d.Dial(proxyAddr.Network(), proxyAddr.String())
if err != nil { if err != nil {
log.Errorf("Failed to connect to %v: %v", ci.ProxyMethod, err) log.Errorf("Failed to connect to %v: %v", ci.ProxyMethod, err)
user.CloseSession(ci.SessionId, "Failed to connect to proxy server") user.CloseSession(ci.SessionId, "Failed to connect to proxy server")

View File

@ -24,6 +24,7 @@ type rawConfig struct {
Transport string Transport string
NumConn int NumConn int
StreamTimeout int StreamTimeout int
KeepAlive int
RemoteHost string RemoteHost string
RemotePort int RemotePort int
} }
@ -50,6 +51,7 @@ type State struct {
ServerName string ServerName string
NumConn int NumConn int
Timeout time.Duration Timeout time.Duration
KeepAlive time.Duration
} }
// semi-colon separated value. This is for Android plugin options // semi-colon separated value. This is for Android plugin options
@ -138,6 +140,11 @@ func (sta *State) ParseConfig(conf string) (err error) {
} else { } else {
sta.Timeout = time.Duration(preParse.StreamTimeout) * time.Second sta.Timeout = time.Duration(preParse.StreamTimeout) * time.Second
} }
if preParse.KeepAlive <= 0 {
sta.KeepAlive = -1
} else {
sta.KeepAlive = time.Duration(preParse.KeepAlive) * time.Second
}
sta.UID = preParse.UID sta.UID = preParse.UID
pub, ok := ecdh.Unmarshal(preParse.PublicKey) pub, ok := ecdh.Unmarshal(preParse.PublicKey)

View File

@ -24,6 +24,7 @@ type rawConfig struct {
AdminUID []byte AdminUID []byte
DatabasePath string DatabasePath string
StreamTimeout int StreamTimeout int
KeepAlive int
CncMode bool CncMode bool
} }
@ -32,9 +33,10 @@ type State struct {
BindAddr []net.Addr BindAddr []net.Addr
ProxyBook map[string]net.Addr ProxyBook map[string]net.Addr
Now func() time.Time Now func() time.Time
AdminUID []byte AdminUID []byte
Timeout time.Duration Timeout time.Duration
KeepAlive time.Duration
BypassUID map[[16]byte]struct{} BypassUID map[[16]byte]struct{}
staticPv crypto.PrivateKey staticPv crypto.PrivateKey
@ -173,6 +175,12 @@ func (sta *State) ParseConfig(conf string) (err error) {
sta.Timeout = time.Duration(preParse.StreamTimeout) * time.Second sta.Timeout = time.Duration(preParse.StreamTimeout) * time.Second
} }
if preParse.KeepAlive <= 0 {
sta.KeepAlive = -1
} else {
sta.KeepAlive = time.Duration(preParse.KeepAlive) * time.Second
}
sta.RedirHost, sta.RedirPort, err = parseRedirAddr(preParse.RedirAddr) sta.RedirHost, sta.RedirPort, err = parseRedirAddr(preParse.RedirAddr)
if err != nil { if err != nil {
return fmt.Errorf("unable to parse RedirAddr: %v", err) return fmt.Errorf("unable to parse RedirAddr: %v", err)