From f27889af11bfecaf70459aae77111b038fc6cb29 Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Tue, 5 Jan 2021 21:48:02 +0000 Subject: [PATCH] Allow partial json to be POSTed to admin/user/{UID} for only updating select fields --- internal/server/usermanager/api.yaml | 8 +- .../server/usermanager/api_router_test.go | 30 +++++++ internal/server/usermanager/localmanager.go | 65 ++++++++------ .../server/usermanager/localmanager_test.go | 84 +++++++++++-------- internal/server/usermanager/usermanager.go | 19 +++-- internal/server/userpanel_test.go | 24 +++--- 6 files changed, 142 insertions(+), 88 deletions(-) diff --git a/internal/server/usermanager/api.yaml b/internal/server/usermanager/api.yaml index 39ec6d7..d49875d 100644 --- a/internal/server/usermanager/api.yaml +++ b/internal/server/usermanager/api.yaml @@ -2,7 +2,7 @@ swagger: '2.0' info: description: | This is the API of Cloak server - version: 1.0.0 + version: 0.0.2 title: Cloak Server contact: email: cbeuw.andy@gmail.com @@ -12,8 +12,6 @@ info: # host: petstore.swagger.io # basePath: /v2 tags: - - name: admin - description: Endpoints used by the host administrators - name: users description: Operations related to user controls by admin # schemes: @@ -22,7 +20,6 @@ paths: /admin/users: get: tags: - - admin - users summary: Show all users description: Returns an array of all UserInfo @@ -41,7 +38,6 @@ paths: /admin/users/{UID}: get: tags: - - admin - users summary: Show userinfo by UID description: Returns a UserInfo object @@ -68,7 +64,6 @@ paths: description: internal error post: tags: - - admin - users summary: Updates the userinfo of the specified user, if the user does not exist, then a new user is created operationId: writeUserInfo @@ -100,7 +95,6 @@ paths: description: internal error delete: tags: - - admin - users summary: Deletes a user operationId: deleteUser diff --git a/internal/server/usermanager/api_router_test.go b/internal/server/usermanager/api_router_test.go index 1310f60..9f957c8 100644 --- a/internal/server/usermanager/api_router_test.go +++ b/internal/server/usermanager/api_router_test.go @@ -46,6 +46,36 @@ func TestWriteUserInfoHlr(t *testing.T) { assert.Equalf(t, http.StatusCreated, rr.Code, "response body: %v", rr.Body) }) + t.Run("partial update", func(t *testing.T) { + req, err := http.NewRequest("POST", "/admin/users/"+mockUIDb64, bytes.NewBuffer(marshalled)) + assert.NoError(t, err) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + assert.Equal(t, http.StatusCreated, rr.Code) + + partialUserInfo := UserInfo{ + UID: mockUID, + SessionsCap: JustInt32(10), + } + partialMarshalled, _ := json.Marshal(partialUserInfo) + req, err = http.NewRequest("POST", "/admin/users/"+mockUIDb64, bytes.NewBuffer(partialMarshalled)) + assert.NoError(t, err) + router.ServeHTTP(rr, req) + assert.Equal(t, http.StatusCreated, rr.Code) + + req, err = http.NewRequest("GET", "/admin/users/"+mockUIDb64, nil) + assert.NoError(t, err) + router.ServeHTTP(rr, req) + assert.Equal(t, http.StatusCreated, rr.Code) + var got UserInfo + err = json.Unmarshal(rr.Body.Bytes(), &got) + assert.NoError(t, err) + + expected := mockUserInfo + expected.SessionsCap = partialUserInfo.SessionsCap + assert.EqualValues(t, expected, got) + }) + t.Run("empty parameter", func(t *testing.T) { req, err := http.NewRequest("POST", "/admin/users/", bytes.NewBuffer(marshalled)) if err != nil { diff --git a/internal/server/usermanager/localmanager.go b/internal/server/usermanager/localmanager.go index 1689595..d60f62e 100644 --- a/internal/server/usermanager/localmanager.go +++ b/internal/server/usermanager/localmanager.go @@ -127,6 +127,7 @@ func (manager *localManager) UploadStatus(uploads []StatusUpdate) ([]StatusRespo "User no longer exists", } responses = append(responses, resp) + continue } oldUp := int64(u64(bucket.Get([]byte("UpCredit")))) @@ -179,12 +180,12 @@ func (manager *localManager) ListAllUsers() (infos []UserInfo, err error) { err = tx.ForEach(func(UID []byte, bucket *bolt.Bucket) error { var uinfo UserInfo uinfo.UID = UID - uinfo.SessionsCap = int32(u32(bucket.Get([]byte("SessionsCap")))) - uinfo.UpRate = int64(u64(bucket.Get([]byte("UpRate")))) - uinfo.DownRate = int64(u64(bucket.Get([]byte("DownRate")))) - uinfo.UpCredit = int64(u64(bucket.Get([]byte("UpCredit")))) - uinfo.DownCredit = int64(u64(bucket.Get([]byte("DownCredit")))) - uinfo.ExpiryTime = int64(u64(bucket.Get([]byte("ExpiryTime")))) + uinfo.SessionsCap = JustInt32(int32(u32(bucket.Get([]byte("SessionsCap"))))) + uinfo.UpRate = JustInt64(int64(u64(bucket.Get([]byte("UpRate"))))) + uinfo.DownRate = JustInt64(int64(u64(bucket.Get([]byte("DownRate"))))) + uinfo.UpCredit = JustInt64(int64(u64(bucket.Get([]byte("UpCredit"))))) + uinfo.DownCredit = JustInt64(int64(u64(bucket.Get([]byte("DownCredit"))))) + uinfo.ExpiryTime = JustInt64(int64(u64(bucket.Get([]byte("ExpiryTime"))))) infos = append(infos, uinfo) return nil }) @@ -200,40 +201,52 @@ func (manager *localManager) GetUserInfo(UID []byte) (uinfo UserInfo, err error) return ErrUserNotFound } uinfo.UID = UID - uinfo.SessionsCap = int32(u32(bucket.Get([]byte("SessionsCap")))) - uinfo.UpRate = int64(u64(bucket.Get([]byte("UpRate")))) - uinfo.DownRate = int64(u64(bucket.Get([]byte("DownRate")))) - uinfo.UpCredit = int64(u64(bucket.Get([]byte("UpCredit")))) - uinfo.DownCredit = int64(u64(bucket.Get([]byte("DownCredit")))) - uinfo.ExpiryTime = int64(u64(bucket.Get([]byte("ExpiryTime")))) + uinfo.SessionsCap = JustInt32(int32(u32(bucket.Get([]byte("SessionsCap"))))) + uinfo.UpRate = JustInt64(int64(u64(bucket.Get([]byte("UpRate"))))) + uinfo.DownRate = JustInt64(int64(u64(bucket.Get([]byte("DownRate"))))) + uinfo.UpCredit = JustInt64(int64(u64(bucket.Get([]byte("UpCredit"))))) + uinfo.DownCredit = JustInt64(int64(u64(bucket.Get([]byte("DownCredit"))))) + uinfo.ExpiryTime = JustInt64(int64(u64(bucket.Get([]byte("ExpiryTime"))))) return nil }) return } -func (manager *localManager) WriteUserInfo(uinfo UserInfo) (err error) { +func (manager *localManager) WriteUserInfo(u UserInfo) (err error) { err = manager.db.Update(func(tx *bolt.Tx) error { - bucket, err := tx.CreateBucketIfNotExists(uinfo.UID) + bucket, err := tx.CreateBucketIfNotExists(u.UID) if err != nil { return err } - if err = bucket.Put([]byte("SessionsCap"), i32ToB(int32(uinfo.SessionsCap))); err != nil { - return err + if u.SessionsCap != nil { + if err = bucket.Put([]byte("SessionsCap"), i32ToB(*u.SessionsCap)); err != nil { + return err + } } - if err = bucket.Put([]byte("UpRate"), i64ToB(uinfo.UpRate)); err != nil { - return err + if u.UpRate != nil { + if err = bucket.Put([]byte("UpRate"), i64ToB(*u.UpRate)); err != nil { + return err + } } - if err = bucket.Put([]byte("DownRate"), i64ToB(uinfo.DownRate)); err != nil { - return err + if u.DownRate != nil { + if err = bucket.Put([]byte("DownRate"), i64ToB(*u.DownRate)); err != nil { + return err + } } - if err = bucket.Put([]byte("UpCredit"), i64ToB(uinfo.UpCredit)); err != nil { - return err + if u.UpCredit != nil { + if err = bucket.Put([]byte("UpCredit"), i64ToB(*u.UpCredit)); err != nil { + return err + } } - if err = bucket.Put([]byte("DownCredit"), i64ToB(uinfo.DownCredit)); err != nil { - return err + if u.DownCredit != nil { + if err = bucket.Put([]byte("DownCredit"), i64ToB(*u.DownCredit)); err != nil { + return err + } } - if err = bucket.Put([]byte("ExpiryTime"), i64ToB(uinfo.ExpiryTime)); err != nil { - return err + if u.ExpiryTime != nil { + if err = bucket.Put([]byte("ExpiryTime"), i64ToB(*u.ExpiryTime)); err != nil { + return err + } } return nil }) diff --git a/internal/server/usermanager/localmanager_test.go b/internal/server/usermanager/localmanager_test.go index 9e9370b..40873cc 100644 --- a/internal/server/usermanager/localmanager_test.go +++ b/internal/server/usermanager/localmanager_test.go @@ -3,6 +3,7 @@ package usermanager import ( "encoding/binary" "github.com/cbeuw/Cloak/internal/common" + "github.com/stretchr/testify/assert" "io/ioutil" "math/rand" "os" @@ -17,12 +18,12 @@ var mockUID = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} var mockWorldState = common.WorldOfTime(time.Unix(1, 0)) var mockUserInfo = UserInfo{ UID: mockUID, - SessionsCap: 0, - UpRate: 0, - DownRate: 0, - UpCredit: 0, - DownCredit: 0, - ExpiryTime: 100, + SessionsCap: JustInt32(10), + UpRate: JustInt64(100), + DownRate: JustInt64(1000), + UpCredit: JustInt64(10000), + DownCredit: JustInt64(100000), + ExpiryTime: JustInt64(1000000), } func makeManager(t *testing.T) (mgr *localManager, cleaner func()) { @@ -43,6 +44,23 @@ func TestLocalManager_WriteUserInfo(t *testing.T) { if err != nil { t.Error(err) } + + got, err := mgr.GetUserInfo(mockUID) + assert.NoError(t, err) + assert.EqualValues(t, mockUserInfo, got) + + /* Partial update */ + err = mgr.WriteUserInfo(UserInfo{ + UID: mockUID, + SessionsCap: JustInt32(*mockUserInfo.SessionsCap + 1), + }) + assert.NoError(t, err) + + expected := mockUserInfo + expected.SessionsCap = JustInt32(*mockUserInfo.SessionsCap + 1) + got, err = mgr.GetUserInfo(mockUID) + assert.NoError(t, err) + assert.EqualValues(t, expected, got) } func TestLocalManager_GetUserInfo(t *testing.T) { @@ -63,7 +81,7 @@ func TestLocalManager_GetUserInfo(t *testing.T) { t.Run("update a field", func(t *testing.T) { _ = mgr.WriteUserInfo(mockUserInfo) updatedUserInfo := mockUserInfo - updatedUserInfo.SessionsCap = mockUserInfo.SessionsCap + 1 + updatedUserInfo.SessionsCap = JustInt32(*mockUserInfo.SessionsCap + 1) err := mgr.WriteUserInfo(updatedUserInfo) if err != nil { @@ -103,15 +121,7 @@ func TestLocalManager_DeleteUser(t *testing.T) { } } -var validUserInfo = UserInfo{ - UID: mockUID, - SessionsCap: 10, - UpRate: 100, - DownRate: 1000, - UpCredit: 10000, - DownCredit: 100000, - ExpiryTime: 1000000, -} +var validUserInfo = mockUserInfo func TestLocalManager_AuthenticateUser(t *testing.T) { var tmpDB, _ = ioutil.TempFile("", "ck_user_info") @@ -128,7 +138,7 @@ func TestLocalManager_AuthenticateUser(t *testing.T) { t.Error(err) } - if upRate != validUserInfo.UpRate || downRate != validUserInfo.DownRate { + if upRate != *validUserInfo.UpRate || downRate != *validUserInfo.DownRate { t.Error("wrong up or down rate") } }) @@ -142,7 +152,7 @@ func TestLocalManager_AuthenticateUser(t *testing.T) { t.Run("expired user", func(t *testing.T) { expiredUserInfo := validUserInfo - expiredUserInfo.ExpiryTime = mockWorldState.Now().Add(-10 * time.Second).Unix() + expiredUserInfo.ExpiryTime = JustInt64(mockWorldState.Now().Add(-10 * time.Second).Unix()) _ = mgr.WriteUserInfo(expiredUserInfo) @@ -154,7 +164,7 @@ func TestLocalManager_AuthenticateUser(t *testing.T) { t.Run("no credit", func(t *testing.T) { creditlessUserInfo := validUserInfo - creditlessUserInfo.UpCredit, creditlessUserInfo.DownCredit = -1, -1 + creditlessUserInfo.UpCredit, creditlessUserInfo.DownCredit = JustInt64(-1), JustInt64(-1) _ = mgr.WriteUserInfo(creditlessUserInfo) @@ -186,7 +196,7 @@ func TestLocalManager_AuthoriseNewSession(t *testing.T) { t.Run("expired user", func(t *testing.T) { expiredUserInfo := validUserInfo - expiredUserInfo.ExpiryTime = mockWorldState.Now().Add(-10 * time.Second).Unix() + expiredUserInfo.ExpiryTime = JustInt64(mockWorldState.Now().Add(-10 * time.Second).Unix()) _ = mgr.WriteUserInfo(expiredUserInfo) err := mgr.AuthoriseNewSession(expiredUserInfo.UID, AuthorisationInfo{NumExistingSessions: 0}) @@ -197,7 +207,7 @@ func TestLocalManager_AuthoriseNewSession(t *testing.T) { t.Run("too many sessions", func(t *testing.T) { _ = mgr.WriteUserInfo(validUserInfo) - err := mgr.AuthoriseNewSession(validUserInfo.UID, AuthorisationInfo{NumExistingSessions: int(validUserInfo.SessionsCap + 1)}) + err := mgr.AuthoriseNewSession(validUserInfo.UID, AuthorisationInfo{NumExistingSessions: int(*validUserInfo.SessionsCap + 1)}) if err != ErrSessionsCapReached { t.Error("session cap not reached") } @@ -230,10 +240,10 @@ func TestLocalManager_UploadStatus(t *testing.T) { t.Error(err) } - if updatedUserInfo.UpCredit != validUserInfo.UpCredit-update.UpUsage { + if *updatedUserInfo.UpCredit != *validUserInfo.UpCredit-update.UpUsage { t.Error("up usage incorrect") } - if updatedUserInfo.DownCredit != validUserInfo.DownCredit-update.DownUsage { + if *updatedUserInfo.DownCredit != *validUserInfo.DownCredit-update.DownUsage { t.Error("down usage incorrect") } }) @@ -249,7 +259,7 @@ func TestLocalManager_UploadStatus(t *testing.T) { UID: validUserInfo.UID, Active: true, NumSession: 1, - UpUsage: validUserInfo.UpCredit + 100, + UpUsage: *validUserInfo.UpCredit + 100, DownUsage: 0, Timestamp: mockWorldState.Now().Unix(), }, @@ -261,19 +271,19 @@ func TestLocalManager_UploadStatus(t *testing.T) { Active: true, NumSession: 1, UpUsage: 0, - DownUsage: validUserInfo.DownCredit + 100, + DownUsage: *validUserInfo.DownCredit + 100, Timestamp: mockWorldState.Now().Unix(), }, }, {"expired", UserInfo{ UID: mockUID, - SessionsCap: 10, - UpRate: 0, - DownRate: 0, - UpCredit: 0, - DownCredit: 0, - ExpiryTime: -1, + SessionsCap: JustInt32(10), + UpRate: JustInt64(0), + DownRate: JustInt64(0), + UpCredit: JustInt64(0), + DownCredit: JustInt64(0), + ExpiryTime: JustInt64(-1), }, StatusUpdate{ UID: mockUserInfo.UID, @@ -318,12 +328,12 @@ func TestLocalManager_ListAllUsers(t *testing.T) { rand.Read(randUID) newUser := UserInfo{ UID: randUID, - SessionsCap: rand.Int31(), - UpRate: rand.Int63(), - DownRate: rand.Int63(), - UpCredit: rand.Int63(), - DownCredit: rand.Int63(), - ExpiryTime: rand.Int63(), + SessionsCap: JustInt32(rand.Int31()), + UpRate: JustInt64(rand.Int63()), + DownRate: JustInt64(rand.Int63()), + UpCredit: JustInt64(rand.Int63()), + DownCredit: JustInt64(rand.Int63()), + ExpiryTime: JustInt64(rand.Int63()), } users = append(users, newUser) wg.Add(1) diff --git a/internal/server/usermanager/usermanager.go b/internal/server/usermanager/usermanager.go index 7bf84d5..bb5456e 100644 --- a/internal/server/usermanager/usermanager.go +++ b/internal/server/usermanager/usermanager.go @@ -14,16 +14,23 @@ type StatusUpdate struct { Timestamp int64 } +type MaybeInt32 *int32 +type MaybeInt64 *int64 + type UserInfo struct { UID []byte - SessionsCap int32 - UpRate int64 - DownRate int64 - UpCredit int64 - DownCredit int64 - ExpiryTime int64 + SessionsCap MaybeInt32 + UpRate MaybeInt64 + DownRate MaybeInt64 + UpCredit MaybeInt64 + DownCredit MaybeInt64 + ExpiryTime MaybeInt64 } +func JustInt32(v int32) MaybeInt32 { return &v } + +func JustInt64(v int64) MaybeInt64 { return &v } + type StatusResponse struct { UID []byte Action int diff --git a/internal/server/userpanel_test.go b/internal/server/userpanel_test.go index f74d3e9..b28744c 100644 --- a/internal/server/userpanel_test.go +++ b/internal/server/userpanel_test.go @@ -66,12 +66,12 @@ var mockUID = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} var mockWorldState = common.WorldOfTime(time.Unix(1, 0)) var validUserInfo = usermanager.UserInfo{ UID: mockUID, - SessionsCap: 10, - UpRate: 100, - DownRate: 1000, - UpCredit: 10000, - DownCredit: 100000, - ExpiryTime: 1000000, + SessionsCap: usermanager.JustInt32(10), + UpRate: usermanager.JustInt64(100), + DownRate: usermanager.JustInt64(1000), + UpCredit: usermanager.JustInt64(10000), + DownCredit: usermanager.JustInt64(100000), + ExpiryTime: usermanager.JustInt64(1000000), } func TestUserPanel_GetUser(t *testing.T) { @@ -138,10 +138,10 @@ func TestUserPanel_UpdateUsageQueue(t *testing.T) { } updatedUinfo, _ := mgr.GetUserInfo(validUserInfo.UID) - if updatedUinfo.DownCredit != validUserInfo.DownCredit-1 { + if *updatedUinfo.DownCredit != *validUserInfo.DownCredit-1 { t.Error("down credit incorrect update") } - if updatedUinfo.UpCredit != validUserInfo.UpCredit-2 { + if *updatedUinfo.UpCredit != *validUserInfo.UpCredit-2 { t.Error("up credit incorrect update") } @@ -155,10 +155,10 @@ func TestUserPanel_UpdateUsageQueue(t *testing.T) { } updatedUinfo, _ = mgr.GetUserInfo(validUserInfo.UID) - if updatedUinfo.DownCredit != validUserInfo.DownCredit-(1+3) { + if *updatedUinfo.DownCredit != *validUserInfo.DownCredit-(1+3) { t.Error("down credit incorrect update") } - if updatedUinfo.UpCredit != validUserInfo.UpCredit-(2+4) { + if *updatedUinfo.UpCredit != *validUserInfo.UpCredit-(2+4) { t.Error("up credit incorrect update") } }) @@ -170,7 +170,7 @@ func TestUserPanel_UpdateUsageQueue(t *testing.T) { t.Error(err) } - user.valve.AddTx(validUserInfo.DownCredit + 100) + user.valve.AddTx(*validUserInfo.DownCredit + 100) panel.updateUsageQueue() err = panel.commitUpdate() if err != nil { @@ -182,7 +182,7 @@ func TestUserPanel_UpdateUsageQueue(t *testing.T) { } updatedUinfo, _ := mgr.GetUserInfo(validUserInfo.UID) - if updatedUinfo.DownCredit != -100 { + if *updatedUinfo.DownCredit != -100 { t.Error("down credit not updated correctly after the user has been terminated") } })