From 4b6ab1b4d5960a62d376fb1b09e7513851ec97aa Mon Sep 17 00:00:00 2001 From: Qian Wang Date: Wed, 26 Dec 2018 00:46:39 +0000 Subject: [PATCH] Finish the admin control --- cmd/ck-client/admin.go | 133 ++++++++++++++++++++-- cmd/ck-client/ck-client.go | 26 ----- internal/multiplex/switchboard.go | 1 + internal/server/usermanager/controller.go | 121 +++++++++++++++----- internal/server/usermanager/userpanel.go | 47 ++++++++ 5 files changed, 259 insertions(+), 69 deletions(-) diff --git a/cmd/ck-client/admin.go b/cmd/ck-client/admin.go index 9b00a52..d931e5e 100644 --- a/cmd/ck-client/admin.go +++ b/cmd/ck-client/admin.go @@ -15,6 +15,7 @@ import ( "encoding/json" "errors" "fmt" + "log" "net" "github.com/cbeuw/Cloak/internal/client" @@ -38,6 +39,47 @@ type administrator struct { adminUID []byte } +func adminPrompt(sta *client.State) error { + a, err := adminHandshake(sta) + if err != nil { + log.Println(err) + return err + } + fmt.Println(`1 listActiveUsers none []uids +2 listAllUsers none []userinfo +3 getUserInfo uid userinfo +4 addNewUser userinfo ok +5 delUser uid ok +6 syncMemFromDB uid ok + +7 setSessionsCap uid cap ok +8 setUpRate uid rate ok +9 setDownRate uid rate ok +10 setUpCredit uid credit ok +11 setDownCredit uid credit ok +12 setExpiryTime uid time ok +13 addUpCredit uid delta ok +14 addDownCredit uid delta ok`) + buf := make([]byte, 16000) + for { + req, err := a.getRequest() + if err != nil { + log.Println(err) + continue + } + a.adminConn.Write(req) + n, err := a.adminConn.Read(buf) + if err != nil { + return err + } + resp, err := a.checkAndDecrypt(buf[:n]) + if err != nil { + return err + } + fmt.Println(string(resp)) + } +} + func adminHandshake(sta *client.State) (*administrator, error) { fmt.Println("Enter the ip:port of your server") var addr string @@ -76,27 +118,42 @@ func adminHandshake(sta *client.State) (*administrator, error) { } func (a *administrator) getRequest() (req []byte, err error) { + promptUID := func() []byte { + fmt.Println("Enter UID") + var b64UID string + fmt.Scanln(&b64UID) + ret, _ := base64.StdEncoding.DecodeString(b64UID) + return ret + } + + promptInt64 := func(name string) []byte { + fmt.Println("Enter New " + name) + var val int64 + fmt.Scanln(&val) + ret := make([]byte, 8) + binary.BigEndian.PutUint64(ret, uint64(val)) + return ret + } + promptUint32 := func(name string) []byte { + fmt.Println("Enter New " + name) + var val uint32 + fmt.Scanln(&val) + ret := make([]byte, 4) + binary.BigEndian.PutUint32(ret, val) + return ret + } + fmt.Println("Select your command") - fmt.Println(`1 listActiveUsers none []uids -2 listAllUsers none []userinfo -3 getUserInfo uid userinfo -4 addNewUser userinfo ok`) var cmd string fmt.Scanln(&cmd) switch cmd { case "1": req = a.request([]byte{0x01}) - return case "2": req = a.request([]byte{0x02}) - return case "3": - fmt.Println("Enter UID") - var b64UID string - fmt.Scanln(&b64UID) - UID, _ := base64.StdEncoding.DecodeString(b64UID) + UID := promptUID() req = a.request(append([]byte{0x03}, UID...)) - return case "4": var uinfo UserInfo var b64UID string @@ -118,10 +175,62 @@ func (a *administrator) getRequest() (req []byte, err error) { fmt.Scanf("%d", &uinfo.ExpiryTime) marshed, _ := json.Marshal(uinfo) req = a.request(append([]byte{0x04}, marshed...)) - return + case "5": + UID := promptUID() + fmt.Println("Are you sure to delete this user? y/n") + var ans string + fmt.Scanln(&ans) + if ans != "y" && ans != "Y" { + return + } + req = a.request(append([]byte{0x05}, UID...)) + case "6": + UID := promptUID() + req = a.request(append([]byte{0x06}, UID...)) + case "7": + arg := make([]byte, 36) + copy(arg, promptUID()) + copy(arg[32:], promptUint32("SessionsCap")) + req = a.request(append([]byte{0x07}, arg...)) + case "8": + arg := make([]byte, 40) + copy(arg, promptUID()) + copy(arg[32:], promptInt64("UpRate")) + req = a.request(append([]byte{0x08}, arg...)) + case "9": + arg := make([]byte, 40) + copy(arg, promptUID()) + copy(arg[32:], promptInt64("DownRate")) + req = a.request(append([]byte{0x09}, arg...)) + case "10": + arg := make([]byte, 40) + copy(arg, promptUID()) + copy(arg[32:], promptInt64("UpCredit")) + req = a.request(append([]byte{0x0a}, arg...)) + case "11": + arg := make([]byte, 40) + copy(arg, promptUID()) + copy(arg[32:], promptInt64("DownCredit")) + req = a.request(append([]byte{0x0b}, arg...)) + case "12": + arg := make([]byte, 40) + copy(arg, promptUID()) + copy(arg[32:], promptInt64("ExpiryTime")) + req = a.request(append([]byte{0x0c}, arg...)) + case "13": + arg := make([]byte, 40) + copy(arg, promptUID()) + copy(arg[32:], promptInt64("UpCredit to add")) + req = a.request(append([]byte{0x0d}, arg...)) + case "14": + arg := make([]byte, 40) + copy(arg, promptUID()) + copy(arg[32:], promptInt64("DownCredit to add")) + req = a.request(append([]byte{0x0e}, arg...)) default: return nil, errors.New("Unreconised cmd") } + return req, nil } // protocol: 0[TLS record layer 5 bytes]5[IV 16 bytes]21[data][hmac 32 bytes] diff --git a/cmd/ck-client/ck-client.go b/cmd/ck-client/ck-client.go index 2f86446..3c33311 100644 --- a/cmd/ck-client/ck-client.go +++ b/cmd/ck-client/ck-client.go @@ -81,32 +81,6 @@ func makeRemoteConn(sta *client.State) (net.Conn, error) { } -func adminPrompt(sta *client.State) error { - a, err := adminHandshake(sta) - if err != nil { - return err - } - log.Println(err) - buf := make([]byte, 16000) - for { - req, err := a.getRequest() - if err != nil { - log.Println(err) - continue - } - a.adminConn.Write(req) - n, err := a.adminConn.Read(buf) - if err != nil { - return err - } - resp, err := a.checkAndDecrypt(buf[:n]) - if err != nil { - return err - } - fmt.Println(string(resp)) - } -} - func main() { // Should be 127.0.0.1 to listen to ss-local on this machine var localHost string diff --git a/internal/multiplex/switchboard.go b/internal/multiplex/switchboard.go index 4c30434..131cb51 100644 --- a/internal/multiplex/switchboard.go +++ b/internal/multiplex/switchboard.go @@ -76,6 +76,7 @@ func (sb *switchboard) send(data []byte) (int, error) { if err != nil { return n, err } + sb.txWait(n) if sb.AddTxCredit(-int64(n)) < 0 { log.Println(ErrNoTxCredit) defer sb.session.Close() diff --git a/internal/server/usermanager/controller.go b/internal/server/usermanager/controller.go index 20bdef0..5d6b1b1 100644 --- a/internal/server/usermanager/controller.go +++ b/internal/server/usermanager/controller.go @@ -29,7 +29,7 @@ import ( 10 setUpCredit uid credit ok 11 setDownCredit uid credit ok 12 setExpiryTime uid time ok -13 addUpcredit uid delta ok +13 addUpCredit uid delta ok 14 addDownCredit uid delta ok */ @@ -42,7 +42,16 @@ func (up *Userpanel) MakeController(adminUID []byte) *controller { return &controller{up, adminUID} } -func (c *controller) HandleRequest(req []byte) ([]byte, error) { +var errInvalidArgument = errors.New("Invalid argument format") + +func (c *controller) HandleRequest(req []byte) (resp []byte, err error) { + check := func(err error) []byte { + if err != nil { + return c.respond([]byte(err.Error())) + } else { + return c.respond([]byte("ok")) + } + } plain, err := c.checkAndDecrypt(req) if err == ErrInvalidMac { log.Printf("!!!CONTROL MESSAGE AND HMAC MISMATCH!!!\n raw request:\n%x\ndecrypted msg:\n%x", req, plain) @@ -52,55 +61,105 @@ func (c *controller) HandleRequest(req []byte) ([]byte, error) { return c.respond([]byte(err.Error())), nil } - switch plain[0] { + typ := plain[0] + var arg []byte + if len(plain) > 1 { + arg = plain[1:] + } + switch typ { case 1: UIDs := c.listActiveUsers() - resp, _ := json.Marshal(UIDs) - return c.respond(resp), nil + resp, _ = json.Marshal(UIDs) + resp = c.respond(resp) case 2: uinfos := c.listAllUsers() - resp, _ := json.Marshal(uinfos) - return c.respond(resp), nil + resp, _ = json.Marshal(uinfos) + resp = c.respond(resp) case 3: - uinfo, err := c.getUserInfo(plain[1:33]) + uinfo, err := c.getUserInfo(arg) if err != nil { - return c.respond([]byte(err.Error())), nil + resp = c.respond([]byte(err.Error())) + break } - resp, _ := json.Marshal(uinfo) - return c.respond(resp), nil + resp, _ = json.Marshal(uinfo) + resp = c.respond(resp) case 4: var uinfo UserInfo - err = json.Unmarshal(plain[1:], &uinfo) + err = json.Unmarshal(arg, &uinfo) if err != nil { - return c.respond([]byte(err.Error())), nil + resp = c.respond([]byte(err.Error())) + break } err = c.addNewUser(uinfo) - if err != nil { - return c.respond([]byte(err.Error())), nil - } else { - return c.respond([]byte("ok")), nil - } + resp = check(err) case 5: - err = c.delUser(plain[1:]) - if err != nil { - return c.respond([]byte(err.Error())), nil - } else { - return c.respond([]byte("ok")), nil - } - + err = c.delUser(arg) + resp = check(err) case 6: - err = c.syncMemFromDB(plain[1:33]) - if err != nil { - return c.respond([]byte(err.Error())), nil - } else { - return c.respond([]byte("ok")), nil + err = c.syncMemFromDB(arg) + resp = check(err) + case 7: + if len(arg) < 36 { + resp = c.respond([]byte(errInvalidArgument.Error())) + break } - // TODO: implement the rest + err = c.setSessionsCap(arg[0:32], Uint32(arg[32:36])) + resp = check(err) + case 8: + if len(arg) < 40 { + resp = c.respond([]byte(errInvalidArgument.Error())) + break + } + err = c.setUpRate(arg[0:32], int64(Uint64(arg[32:40]))) + resp = check(err) + case 9: + if len(arg) < 40 { + resp = c.respond([]byte(errInvalidArgument.Error())) + break + } + err = c.setDownRate(arg[0:32], int64(Uint64(arg[32:40]))) + resp = check(err) + case 10: + if len(arg) < 40 { + resp = c.respond([]byte(errInvalidArgument.Error())) + break + } + err = c.setUpCredit(arg[0:32], int64(Uint64(arg[32:40]))) + resp = check(err) + case 11: + if len(arg) < 40 { + resp = c.respond([]byte(errInvalidArgument.Error())) + break + } + err = c.setDownCredit(arg[0:32], int64(Uint64(arg[32:40]))) + resp = check(err) + case 12: + if len(arg) < 40 { + resp = c.respond([]byte(errInvalidArgument.Error())) + break + } + err = c.setExpiryTime(arg[0:32], int64(Uint64(arg[32:40]))) + resp = check(err) + case 13: + if len(arg) < 40 { + resp = c.respond([]byte(errInvalidArgument.Error())) + break + } + err = c.addUpCredit(arg[0:32], int64(Uint64(arg[32:40]))) + resp = check(err) + case 14: + if len(arg) < 40 { + resp = c.respond([]byte(errInvalidArgument.Error())) + break + } + err = c.addDownCredit(arg[0:32], int64(Uint64(arg[32:40]))) + resp = check(err) default: return c.respond([]byte("Unsupported action")), nil } + return } diff --git a/internal/server/usermanager/userpanel.go b/internal/server/usermanager/userpanel.go index d1b774a..37bd315 100644 --- a/internal/server/usermanager/userpanel.go +++ b/internal/server/usermanager/userpanel.go @@ -416,3 +416,50 @@ func (up *Userpanel) setExpiryTime(UID []byte, time int64) error { u.setExpiryTime(time) return nil } + +func (up *Userpanel) addUpCredit(UID []byte, delta int64) error { + err := up.db.Update(func(tx *bolt.Tx) error { + b := tx.Bucket(UID) + if b == nil { + return ErrUserNotFound + } + old := b.Get([]byte("UpCredit")) + new := int64(Uint64(old)) + delta + if err := b.Put([]byte("UpCredit"), i64ToB(new)); err != nil { + return err + } + return nil + }) + if err != nil { + return err + } + u := up.getActiveUser(UID) + if u == nil { + return nil + } + u.addUpCredit(delta) + return nil +} +func (up *Userpanel) addDownCredit(UID []byte, delta int64) error { + err := up.db.Update(func(tx *bolt.Tx) error { + b := tx.Bucket(UID) + if b == nil { + return ErrUserNotFound + } + old := b.Get([]byte("DownCredit")) + new := int64(Uint64(old)) + delta + if err := b.Put([]byte("DownCredit"), i64ToB(new)); err != nil { + return err + } + return nil + }) + if err != nil { + return err + } + u := up.getActiveUser(UID) + if u == nil { + return nil + } + u.addDownCredit(delta) + return nil +}