diff --git a/internal/server/auth.go b/internal/server/auth.go index aa13f47..40eae1b 100644 --- a/internal/server/auth.go +++ b/internal/server/auth.go @@ -15,20 +15,13 @@ var ErrCiphertextLength = errors.New("ciphertext has the wrong length") var ErrTimestampOutOfWindow = errors.New("timestamp is outside of the accepting window") func TouchStone(ch *ClientHello, sta *State) (UID []byte, sessionID uint32, proxyMethod string, encryptionMethod byte, sharedSecret []byte, err error) { - var random [32]byte - copy(random[:], ch.random) - sta.usedRandomM.Lock() - used := sta.usedRandom[random] - sta.usedRandom[random] = int(sta.Now().Unix()) - sta.usedRandomM.Unlock() - - if used != 0 { + if sta.registerRandom(ch.random) { err = ErrReplay return } - ephPub, ok := ecdh.Unmarshal(random[:]) + ephPub, ok := ecdh.Unmarshal(ch.random) if !ok { err = ErrInvalidPubKey return @@ -48,7 +41,7 @@ func TouchStone(ch *ClientHello, sta *State) (UID []byte, sessionID uint32, prox } var plaintext []byte - plaintext, err = util.AESGCMDecrypt(random[0:12], sharedSecret, ciphertext) + plaintext, err = util.AESGCMDecrypt(ch.random[0:12], sharedSecret, ciphertext) if err != nil { return } diff --git a/internal/server/state.go b/internal/server/state.go index db03005..cd50e11 100644 --- a/internal/server/state.go +++ b/internal/server/state.go @@ -36,7 +36,7 @@ type State struct { RedirAddr string usedRandomM sync.RWMutex - usedRandom map[[32]byte]int + usedRandom map[[32]byte]int64 Panel *userPanel LocalAPIRouter *gmux.Router @@ -48,7 +48,7 @@ func InitState(bindHost, bindPort string, nowFunc func() time.Time) (*State, err BindPort: bindPort, Now: nowFunc, } - ret.usedRandom = make(map[[32]byte]int) + ret.usedRandom = make(map[[32]byte]int64) go ret.UsedRandomCleaner() return ret, nil } @@ -112,13 +112,23 @@ const TIMESTAMP_WINDOW = 12 * time.Hour func (sta *State) UsedRandomCleaner() { for { time.Sleep(TIMESTAMP_WINDOW) - now := int(sta.Now().Unix()) + now := sta.Now().Unix() sta.usedRandomM.Lock() for key, t := range sta.usedRandom { - if now-t > int(TIMESTAMP_WINDOW.Seconds()) { + if now-t > int64(TIMESTAMP_WINDOW.Seconds()) { delete(sta.usedRandom, key) } } sta.usedRandomM.Unlock() } } + +func (sta *State) registerRandom(r []byte) bool { + var random [32]byte + copy(random[:], r) + sta.usedRandomM.Lock() + _, used := sta.usedRandom[random] + sta.usedRandom[random] = sta.Now().Unix() + sta.usedRandomM.Unlock() + return used +}