diff --git a/internal/server/activeuser.go b/internal/server/activeuser.go index 4d0b601..48d6ebe 100644 --- a/internal/server/activeuser.go +++ b/internal/server/activeuser.go @@ -1,6 +1,7 @@ package server import ( + "github.com/cbeuw/Cloak/internal/server/usermanager" "sync" mux "github.com/cbeuw/Cloak/internal/multiplex" @@ -40,7 +41,8 @@ func (u *ActiveUser) GetSession(sessionID uint32, config *mux.SessionConfig) (se return sesh, true, nil } else { if !u.bypass { - err := u.panel.Manager.AuthoriseNewSession(u.arrUID[:], len(u.sessions)) + ainfo := usermanager.AuthorisationInfo{NumExistingSessions: len(u.sessions)} + err := u.panel.Manager.AuthoriseNewSession(u.arrUID[:], ainfo) if err != nil { return nil, false, err } diff --git a/internal/server/usermanager/localmanager.go b/internal/server/usermanager/localmanager.go index 99ba3a1..1814dd0 100644 --- a/internal/server/usermanager/localmanager.go +++ b/internal/server/usermanager/localmanager.go @@ -93,7 +93,7 @@ func (manager *localManager) AuthenticateUser(UID []byte) (int64, int64, error) return upRate, downRate, nil } -func (manager *localManager) AuthoriseNewSession(UID []byte, numExistingSessions int) error { +func (manager *localManager) AuthoriseNewSession(UID []byte, ainfo AuthorisationInfo) error { var arrUID [16]byte copy(arrUID[:], UID) var sessionsCap int @@ -121,12 +121,10 @@ func (manager *localManager) AuthoriseNewSession(UID []byte, numExistingSessions if expiryTime < time.Now().Unix() { return ErrUserExpired } - //user.sessionsM.RLock() - if numExistingSessions >= sessionsCap { - //user.sessionsM.RUnlock() + + if ainfo.NumExistingSessions >= sessionsCap { return ErrSessionsCapReached } - //user.sessionsM.RUnlock() return nil } diff --git a/internal/server/usermanager/usermanager.go b/internal/server/usermanager/usermanager.go index e127aa8..99ac8d9 100644 --- a/internal/server/usermanager/usermanager.go +++ b/internal/server/usermanager/usermanager.go @@ -20,6 +20,10 @@ type StatusResponse struct { Message string } +type AuthorisationInfo struct { + NumExistingSessions int +} + const ( TERMINATE = iota + 1 ) @@ -33,6 +37,6 @@ var ErrUserExpired = errors.New("User has expired") type UserManager interface { AuthenticateUser([]byte) (int64, int64, error) - AuthoriseNewSession([]byte, int) error + AuthoriseNewSession([]byte, AuthorisationInfo) error UploadStatus([]StatusUpdate) ([]StatusResponse, error) }