diff --git a/cmd/ck-server/ck-server.go b/cmd/ck-server/ck-server.go index 9cbcfe0..52c568d 100644 --- a/cmd/ck-server/ck-server.go +++ b/cmd/ck-server/ck-server.go @@ -125,7 +125,7 @@ func dispatchConnection(conn net.Conn, sta *server.State) { } sesh, existing, err := user.GetSession(ci.SessionId, seshConfig) if err != nil { - user.DeleteSession(ci.SessionId, "") + user.CloseSession(ci.SessionId, "") log.Error(err) return } @@ -163,7 +163,7 @@ func dispatchConnection(conn net.Conn, sta *server.State) { "sessionID": ci.SessionId, "reason": sesh.TerminalMsg(), }).Info("Session closed") - user.DeleteSession(ci.SessionId, "") + user.CloseSession(ci.SessionId, "") return } else { continue @@ -173,7 +173,7 @@ func dispatchConnection(conn net.Conn, sta *server.State) { localConn, err := net.Dial(proxyAddr.Network(), proxyAddr.String()) if err != nil { log.Errorf("Failed to connect to %v: %v", ci.ProxyMethod, err) - user.DeleteSession(ci.SessionId, "Failed to connect to proxy server") + user.CloseSession(ci.SessionId, "Failed to connect to proxy server") continue } log.Tracef("%v endpoint has been successfully connected", ci.ProxyMethod) diff --git a/internal/server/activeuser.go b/internal/server/activeuser.go index 280b32b..11a5bd1 100644 --- a/internal/server/activeuser.go +++ b/internal/server/activeuser.go @@ -20,8 +20,8 @@ type ActiveUser struct { sessions map[uint32]*mux.Session } -// DeleteSession closes a session and removes its reference from the user -func (u *ActiveUser) DeleteSession(sessionID uint32, reason string) { +// CloseSession closes a session and removes its reference from the user +func (u *ActiveUser) CloseSession(sessionID uint32, reason string) { u.sessionsM.Lock() sesh, existing := u.sessions[sessionID] if existing { @@ -29,10 +29,11 @@ func (u *ActiveUser) DeleteSession(sessionID uint32, reason string) { sesh.SetTerminalMsg(reason) sesh.Close() } - if len(u.sessions) == 0 { - u.panel.DeleteActiveUser(u) - } + remaining := len(u.sessions) u.sessionsM.Unlock() + if remaining == 0 { + u.panel.TerminateActiveUser(u, "no session left") + } } // GetSession returns the reference to an existing session, or if one such session doesn't exist, it queries @@ -58,17 +59,15 @@ func (u *ActiveUser) GetSession(sessionID uint32, config *mux.SessionConfig) (se } } -// Terminate closes all sessions of this active user -func (u *ActiveUser) Terminate(reason string) { +// closeAllSessions closes all sessions of this active user +func (u *ActiveUser) closeAllSessions(reason string) { u.sessionsM.Lock() - for _, sesh := range u.sessions { - if reason != "" { - sesh.SetTerminalMsg(reason) - } + for sessionID, sesh := range u.sessions { + sesh.SetTerminalMsg(reason) sesh.Close() + delete(u.sessions, sessionID) } u.sessionsM.Unlock() - u.panel.DeleteActiveUser(u) } // NumSession returns the number of active sessions diff --git a/internal/server/activeuser_test.go b/internal/server/activeuser_test.go index b46c3cc..2298da8 100644 --- a/internal/server/activeuser_test.go +++ b/internal/server/activeuser_test.go @@ -64,7 +64,7 @@ func TestActiveUser_Bypass(t *testing.T) { } }) t.Run("delete a session", func(t *testing.T) { - user.DeleteSession(0, "") + user.CloseSession(0, "") if user.NumSession() != 1 { t.Error("number of session is not 1 after deleting one") } @@ -72,11 +72,8 @@ func TestActiveUser_Bypass(t *testing.T) { t.Error("session not closed after deletion") } }) - t.Run("terminating user", func(t *testing.T) { - user.Terminate("") - if panel.isActive(user.arrUID[:]) { - t.Error("user is still active after termination") - } + t.Run("close all sessions", func(t *testing.T) { + user.closeAllSessions("") if !sesh1.IsClosed() { t.Error("session not closed after user termination") } @@ -97,7 +94,7 @@ func TestActiveUser_Bypass(t *testing.T) { } }) t.Run("delete last session", func(t *testing.T) { - user.DeleteSession(0, "") + user.CloseSession(0, "") if panel.isActive(user.arrUID[:]) { t.Error("user still active after last session deleted") } diff --git a/internal/server/userpanel.go b/internal/server/userpanel.go index 2ba2d75..e1761ad 100644 --- a/internal/server/userpanel.go +++ b/internal/server/userpanel.go @@ -79,10 +79,10 @@ func (panel *userPanel) GetUser(UID []byte) (*ActiveUser, error) { return user, nil } -// DeleteActiveUser deletes the references to the active user -func (panel *userPanel) DeleteActiveUser(user *ActiveUser) { - // TODO: terminate the user here? +// TerminateActiveUser terminates a user and deletes its references +func (panel *userPanel) TerminateActiveUser(user *ActiveUser, reason string) { panel.updateUsageQueueForOne(user) + user.closeAllSessions(reason) panel.activeUsersM.Lock() delete(panel.activeUsers, user.arrUID) panel.activeUsersM.Unlock() @@ -182,7 +182,7 @@ func (panel *userPanel) commitUpdate() error { user := panel.activeUsers[arrUID] panel.activeUsersM.RUnlock() if user != nil { - user.Terminate(resp.Message) + panel.TerminateActiveUser(user, resp.Message) } } } diff --git a/internal/server/userpanel_test.go b/internal/server/userpanel_test.go index be10ddb..4e13399 100644 --- a/internal/server/userpanel_test.go +++ b/internal/server/userpanel_test.go @@ -43,14 +43,14 @@ func TestUserPanel_BypassUser(t *testing.T) { t.Error("commit returned", err) } }) - t.Run("DeleteActiveUser", func(t *testing.T) { - panel.DeleteActiveUser(user) + t.Run("TerminateActiveUser", func(t *testing.T) { + panel.TerminateActiveUser(user, "") if panel.isActive(user.arrUID[:]) { t.Error("user still active after deletion", err) } }) t.Run("Repeated delete", func(t *testing.T) { - panel.DeleteActiveUser(user) + panel.TerminateActiveUser(user, "") }) err = manager.Close() if err != nil {