Refactor return value of decryption

This commit is contained in:
Qian Wang 2019-08-12 14:21:42 +01:00
parent 71e48a1947
commit 58cbb73f0f
3 changed files with 40 additions and 29 deletions

View File

@ -52,14 +52,14 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
go util.Pipe(conn, webConn) go util.Pipe(conn, webConn)
} }
UID, sessionID, proxyMethod, encryptionMethod, finishHandshake, err := server.PrepareConnection(data, sta, conn) ci, finishHandshake, err := server.PrepareConnection(data, sta, conn)
if err != nil { if err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"remoteAddr": remoteAddr, "remoteAddr": remoteAddr,
"UID": b64(UID), "UID": b64(ci.UID),
"sessionId": sessionID, "sessionId": ci.SessionId,
"proxyMethod": proxyMethod, "proxyMethod": ci.ProxyMethod,
"encryptionMethod": encryptionMethod, "encryptionMethod": ci.EncryptionMethod,
}).Warn(err) }).Warn(err)
goWeb() goWeb()
return return
@ -67,7 +67,7 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
sessionKey := make([]byte, 32) sessionKey := make([]byte, 32)
rand.Read(sessionKey) rand.Read(sessionKey)
obfuscator, err := mux.GenerateObfs(encryptionMethod, sessionKey) obfuscator, err := mux.GenerateObfs(ci.EncryptionMethod, sessionKey)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
goWeb() goWeb()
@ -77,7 +77,7 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
// adminUID can use the server as normal with unlimited QoS credits. The adminUID is not // adminUID can use the server as normal with unlimited QoS credits. The adminUID is not
// added to the userinfo database. The distinction between going into the admin mode // added to the userinfo database. The distinction between going into the admin mode
// and normal proxy mode is that sessionID needs == 0 for admin mode // and normal proxy mode is that sessionID needs == 0 for admin mode
if bytes.Equal(UID, sta.AdminUID) && sessionID == 0 { if bytes.Equal(ci.UID, sta.AdminUID) && ci.SessionId == 0 {
err = finishHandshake(sessionKey) err = finishHandshake(sessionKey)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
@ -101,14 +101,14 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
} }
var user *server.ActiveUser var user *server.ActiveUser
if sta.IsBypass(UID) { if sta.IsBypass(ci.UID) {
user, err = sta.Panel.GetBypassUser(UID) user, err = sta.Panel.GetBypassUser(ci.UID)
} else { } else {
user, err = sta.Panel.GetUser(UID) user, err = sta.Panel.GetUser(ci.UID)
} }
if err != nil { if err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"UID": b64(UID), "UID": b64(ci.UID),
"remoteAddr": remoteAddr, "remoteAddr": remoteAddr,
"error": err, "error": err,
}).Warn("+1 unauthorised UID") }).Warn("+1 unauthorised UID")
@ -121,9 +121,9 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
Valve: nil, Valve: nil,
UnitRead: util.ReadTLS, UnitRead: util.ReadTLS,
} }
sesh, existing, err := user.GetSession(sessionID, seshConfig) sesh, existing, err := user.GetSession(ci.SessionId, seshConfig)
if err != nil { if err != nil {
user.DeleteSession(sessionID, "") user.DeleteSession(ci.SessionId, "")
log.Error(err) log.Error(err)
return return
} }
@ -147,8 +147,8 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
log.Trace("finished handshake") log.Trace("finished handshake")
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"UID": b64(UID), "UID": b64(ci.UID),
"sessionID": sessionID, "sessionID": ci.SessionId,
}).Info("New session") }).Info("New session")
sesh.AddConnection(conn) sesh.AddConnection(conn)
@ -157,20 +157,20 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
if err != nil { if err != nil {
if err == mux.ErrBrokenSession { if err == mux.ErrBrokenSession {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"UID": b64(UID), "UID": b64(ci.UID),
"sessionID": sessionID, "sessionID": ci.SessionId,
"reason": sesh.TerminalMsg(), "reason": sesh.TerminalMsg(),
}).Info("Session closed") }).Info("Session closed")
user.DeleteSession(sessionID, "") user.DeleteSession(ci.SessionId, "")
return return
} else { } else {
continue continue
} }
} }
localConn, err := net.Dial("tcp", sta.ProxyBook[proxyMethod]) localConn, err := net.Dial("tcp", sta.ProxyBook[ci.ProxyMethod])
if err != nil { if err != nil {
log.Errorf("Failed to connect to %v: %v", proxyMethod, err) log.Errorf("Failed to connect to %v: %v", ci.ProxyMethod, err)
user.DeleteSession(sessionID, "Failed to connect to proxy server") user.DeleteSession(ci.SessionId, "Failed to connect to proxy server")
continue continue
} }
go util.Pipe(localConn, newStream) go util.Pipe(localConn, newStream)

View File

@ -218,7 +218,7 @@ var ErrBadClientHello = errors.New("non (or malformed) ClientHello")
var ErrNotCloak = errors.New("TLS but non-Cloak ClientHello") var ErrNotCloak = errors.New("TLS but non-Cloak ClientHello")
var ErrBadProxyMethod = errors.New("invalid proxy method") var ErrBadProxyMethod = errors.New("invalid proxy method")
func PrepareConnection(firstPacket []byte, sta *State, conn net.Conn) (UID []byte, sessionID uint32, proxyMethod string, encryptionMethod byte, finisher func([]byte) error, err error) { func PrepareConnection(firstPacket []byte, sta *State, conn net.Conn) (info *ClientInfo, finisher func([]byte) error, err error) {
ch, err := parseClientHello(firstPacket) ch, err := parseClientHello(firstPacket)
if err != nil { if err != nil {
log.Debug(err) log.Debug(err)
@ -227,13 +227,13 @@ func PrepareConnection(firstPacket []byte, sta *State, conn net.Conn) (UID []byt
} }
var sharedSecret []byte var sharedSecret []byte
UID, sessionID, proxyMethod, encryptionMethod, sharedSecret, err = TouchStone(ch, sta) info, sharedSecret, err = TouchStone(ch, sta)
if err != nil { if err != nil {
log.Debug(err) log.Debug(err)
err = ErrNotCloak err = ErrNotCloak
return return
} }
if _, ok := sta.ProxyBook[proxyMethod]; !ok { if _, ok := sta.ProxyBook[info.ProxyMethod]; !ok {
err = ErrBadProxyMethod err = ErrBadProxyMethod
return return
} }

View File

@ -10,12 +10,19 @@ import (
"time" "time"
) )
type ClientInfo struct {
UID []byte
SessionId uint32
ProxyMethod string
EncryptionMethod byte
}
var ErrReplay = errors.New("duplicate random") var ErrReplay = errors.New("duplicate random")
var ErrInvalidPubKey = errors.New("public key has invalid format") var ErrInvalidPubKey = errors.New("public key has invalid format")
var ErrCiphertextLength = errors.New("ciphertext has the wrong length") var ErrCiphertextLength = errors.New("ciphertext has the wrong length")
var ErrTimestampOutOfWindow = errors.New("timestamp is outside of the accepting window") var ErrTimestampOutOfWindow = errors.New("timestamp is outside of the accepting window")
func TouchStone(ch *ClientHello, sta *State) (UID []byte, sessionID uint32, proxyMethod string, encryptionMethod byte, sharedSecret []byte, err error) { func TouchStone(ch *ClientHello, sta *State) (info *ClientInfo, sharedSecret []byte, err error) {
if sta.registerRandom(ch.random) { if sta.registerRandom(ch.random) {
err = ErrReplay err = ErrReplay
@ -47,9 +54,13 @@ func TouchStone(ch *ClientHello, sta *State) (UID []byte, sessionID uint32, prox
return return
} }
UID = plaintext[0:16] info = &ClientInfo{
proxyMethod = string(bytes.Trim(plaintext[16:28], "\x00")) UID: plaintext[0:16],
encryptionMethod = plaintext[28] SessionId: 0,
ProxyMethod: string(bytes.Trim(plaintext[16:28], "\x00")),
EncryptionMethod: plaintext[28],
}
timestamp := int64(binary.BigEndian.Uint64(plaintext[29:37])) timestamp := int64(binary.BigEndian.Uint64(plaintext[29:37]))
clientTime := time.Unix(timestamp, 0) clientTime := time.Unix(timestamp, 0)
serverTime := sta.Now() serverTime := sta.Now()
@ -57,6 +68,6 @@ func TouchStone(ch *ClientHello, sta *State) (UID []byte, sessionID uint32, prox
err = fmt.Errorf("%v: received timestamp %v", ErrTimestampOutOfWindow, timestamp) err = fmt.Errorf("%v: received timestamp %v", ErrTimestampOutOfWindow, timestamp)
return return
} }
sessionID = binary.BigEndian.Uint32(plaintext[37:41]) info.SessionId = binary.BigEndian.Uint32(plaintext[37:41])
return return
} }