Finish the admin control

This commit is contained in:
Qian Wang 2018-12-26 00:46:39 +00:00
parent 73aefdeeeb
commit 4b6ab1b4d5
5 changed files with 259 additions and 69 deletions

View File

@ -15,6 +15,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"log"
"net" "net"
"github.com/cbeuw/Cloak/internal/client" "github.com/cbeuw/Cloak/internal/client"
@ -38,6 +39,47 @@ type administrator struct {
adminUID []byte adminUID []byte
} }
func adminPrompt(sta *client.State) error {
a, err := adminHandshake(sta)
if err != nil {
log.Println(err)
return err
}
fmt.Println(`1 listActiveUsers none []uids
2 listAllUsers none []userinfo
3 getUserInfo uid userinfo
4 addNewUser userinfo ok
5 delUser uid ok
6 syncMemFromDB uid ok
7 setSessionsCap uid cap ok
8 setUpRate uid rate ok
9 setDownRate uid rate ok
10 setUpCredit uid credit ok
11 setDownCredit uid credit ok
12 setExpiryTime uid time ok
13 addUpCredit uid delta ok
14 addDownCredit uid delta ok`)
buf := make([]byte, 16000)
for {
req, err := a.getRequest()
if err != nil {
log.Println(err)
continue
}
a.adminConn.Write(req)
n, err := a.adminConn.Read(buf)
if err != nil {
return err
}
resp, err := a.checkAndDecrypt(buf[:n])
if err != nil {
return err
}
fmt.Println(string(resp))
}
}
func adminHandshake(sta *client.State) (*administrator, error) { func adminHandshake(sta *client.State) (*administrator, error) {
fmt.Println("Enter the ip:port of your server") fmt.Println("Enter the ip:port of your server")
var addr string var addr string
@ -76,27 +118,42 @@ func adminHandshake(sta *client.State) (*administrator, error) {
} }
func (a *administrator) getRequest() (req []byte, err error) { func (a *administrator) getRequest() (req []byte, err error) {
promptUID := func() []byte {
fmt.Println("Enter UID")
var b64UID string
fmt.Scanln(&b64UID)
ret, _ := base64.StdEncoding.DecodeString(b64UID)
return ret
}
promptInt64 := func(name string) []byte {
fmt.Println("Enter New " + name)
var val int64
fmt.Scanln(&val)
ret := make([]byte, 8)
binary.BigEndian.PutUint64(ret, uint64(val))
return ret
}
promptUint32 := func(name string) []byte {
fmt.Println("Enter New " + name)
var val uint32
fmt.Scanln(&val)
ret := make([]byte, 4)
binary.BigEndian.PutUint32(ret, val)
return ret
}
fmt.Println("Select your command") fmt.Println("Select your command")
fmt.Println(`1 listActiveUsers none []uids
2 listAllUsers none []userinfo
3 getUserInfo uid userinfo
4 addNewUser userinfo ok`)
var cmd string var cmd string
fmt.Scanln(&cmd) fmt.Scanln(&cmd)
switch cmd { switch cmd {
case "1": case "1":
req = a.request([]byte{0x01}) req = a.request([]byte{0x01})
return
case "2": case "2":
req = a.request([]byte{0x02}) req = a.request([]byte{0x02})
return
case "3": case "3":
fmt.Println("Enter UID") UID := promptUID()
var b64UID string
fmt.Scanln(&b64UID)
UID, _ := base64.StdEncoding.DecodeString(b64UID)
req = a.request(append([]byte{0x03}, UID...)) req = a.request(append([]byte{0x03}, UID...))
return
case "4": case "4":
var uinfo UserInfo var uinfo UserInfo
var b64UID string var b64UID string
@ -118,10 +175,62 @@ func (a *administrator) getRequest() (req []byte, err error) {
fmt.Scanf("%d", &uinfo.ExpiryTime) fmt.Scanf("%d", &uinfo.ExpiryTime)
marshed, _ := json.Marshal(uinfo) marshed, _ := json.Marshal(uinfo)
req = a.request(append([]byte{0x04}, marshed...)) req = a.request(append([]byte{0x04}, marshed...))
case "5":
UID := promptUID()
fmt.Println("Are you sure to delete this user? y/n")
var ans string
fmt.Scanln(&ans)
if ans != "y" && ans != "Y" {
return return
}
req = a.request(append([]byte{0x05}, UID...))
case "6":
UID := promptUID()
req = a.request(append([]byte{0x06}, UID...))
case "7":
arg := make([]byte, 36)
copy(arg, promptUID())
copy(arg[32:], promptUint32("SessionsCap"))
req = a.request(append([]byte{0x07}, arg...))
case "8":
arg := make([]byte, 40)
copy(arg, promptUID())
copy(arg[32:], promptInt64("UpRate"))
req = a.request(append([]byte{0x08}, arg...))
case "9":
arg := make([]byte, 40)
copy(arg, promptUID())
copy(arg[32:], promptInt64("DownRate"))
req = a.request(append([]byte{0x09}, arg...))
case "10":
arg := make([]byte, 40)
copy(arg, promptUID())
copy(arg[32:], promptInt64("UpCredit"))
req = a.request(append([]byte{0x0a}, arg...))
case "11":
arg := make([]byte, 40)
copy(arg, promptUID())
copy(arg[32:], promptInt64("DownCredit"))
req = a.request(append([]byte{0x0b}, arg...))
case "12":
arg := make([]byte, 40)
copy(arg, promptUID())
copy(arg[32:], promptInt64("ExpiryTime"))
req = a.request(append([]byte{0x0c}, arg...))
case "13":
arg := make([]byte, 40)
copy(arg, promptUID())
copy(arg[32:], promptInt64("UpCredit to add"))
req = a.request(append([]byte{0x0d}, arg...))
case "14":
arg := make([]byte, 40)
copy(arg, promptUID())
copy(arg[32:], promptInt64("DownCredit to add"))
req = a.request(append([]byte{0x0e}, arg...))
default: default:
return nil, errors.New("Unreconised cmd") return nil, errors.New("Unreconised cmd")
} }
return req, nil
} }
// protocol: 0[TLS record layer 5 bytes]5[IV 16 bytes]21[data][hmac 32 bytes] // protocol: 0[TLS record layer 5 bytes]5[IV 16 bytes]21[data][hmac 32 bytes]

View File

@ -81,32 +81,6 @@ func makeRemoteConn(sta *client.State) (net.Conn, error) {
} }
func adminPrompt(sta *client.State) error {
a, err := adminHandshake(sta)
if err != nil {
return err
}
log.Println(err)
buf := make([]byte, 16000)
for {
req, err := a.getRequest()
if err != nil {
log.Println(err)
continue
}
a.adminConn.Write(req)
n, err := a.adminConn.Read(buf)
if err != nil {
return err
}
resp, err := a.checkAndDecrypt(buf[:n])
if err != nil {
return err
}
fmt.Println(string(resp))
}
}
func main() { func main() {
// Should be 127.0.0.1 to listen to ss-local on this machine // Should be 127.0.0.1 to listen to ss-local on this machine
var localHost string var localHost string

View File

@ -76,6 +76,7 @@ func (sb *switchboard) send(data []byte) (int, error) {
if err != nil { if err != nil {
return n, err return n, err
} }
sb.txWait(n)
if sb.AddTxCredit(-int64(n)) < 0 { if sb.AddTxCredit(-int64(n)) < 0 {
log.Println(ErrNoTxCredit) log.Println(ErrNoTxCredit)
defer sb.session.Close() defer sb.session.Close()

View File

@ -29,7 +29,7 @@ import (
10 setUpCredit uid credit ok 10 setUpCredit uid credit ok
11 setDownCredit uid credit ok 11 setDownCredit uid credit ok
12 setExpiryTime uid time ok 12 setExpiryTime uid time ok
13 addUpcredit uid delta ok 13 addUpCredit uid delta ok
14 addDownCredit uid delta ok 14 addDownCredit uid delta ok
*/ */
@ -42,7 +42,16 @@ func (up *Userpanel) MakeController(adminUID []byte) *controller {
return &controller{up, adminUID} return &controller{up, adminUID}
} }
func (c *controller) HandleRequest(req []byte) ([]byte, error) { var errInvalidArgument = errors.New("Invalid argument format")
func (c *controller) HandleRequest(req []byte) (resp []byte, err error) {
check := func(err error) []byte {
if err != nil {
return c.respond([]byte(err.Error()))
} else {
return c.respond([]byte("ok"))
}
}
plain, err := c.checkAndDecrypt(req) plain, err := c.checkAndDecrypt(req)
if err == ErrInvalidMac { if err == ErrInvalidMac {
log.Printf("!!!CONTROL MESSAGE AND HMAC MISMATCH!!!\n raw request:\n%x\ndecrypted msg:\n%x", req, plain) log.Printf("!!!CONTROL MESSAGE AND HMAC MISMATCH!!!\n raw request:\n%x\ndecrypted msg:\n%x", req, plain)
@ -52,55 +61,105 @@ func (c *controller) HandleRequest(req []byte) ([]byte, error) {
return c.respond([]byte(err.Error())), nil return c.respond([]byte(err.Error())), nil
} }
switch plain[0] { typ := plain[0]
var arg []byte
if len(plain) > 1 {
arg = plain[1:]
}
switch typ {
case 1: case 1:
UIDs := c.listActiveUsers() UIDs := c.listActiveUsers()
resp, _ := json.Marshal(UIDs) resp, _ = json.Marshal(UIDs)
return c.respond(resp), nil resp = c.respond(resp)
case 2: case 2:
uinfos := c.listAllUsers() uinfos := c.listAllUsers()
resp, _ := json.Marshal(uinfos) resp, _ = json.Marshal(uinfos)
return c.respond(resp), nil resp = c.respond(resp)
case 3: case 3:
uinfo, err := c.getUserInfo(plain[1:33]) uinfo, err := c.getUserInfo(arg)
if err != nil { if err != nil {
return c.respond([]byte(err.Error())), nil resp = c.respond([]byte(err.Error()))
break
} }
resp, _ := json.Marshal(uinfo) resp, _ = json.Marshal(uinfo)
return c.respond(resp), nil resp = c.respond(resp)
case 4: case 4:
var uinfo UserInfo var uinfo UserInfo
err = json.Unmarshal(plain[1:], &uinfo) err = json.Unmarshal(arg, &uinfo)
if err != nil { if err != nil {
return c.respond([]byte(err.Error())), nil resp = c.respond([]byte(err.Error()))
break
} }
err = c.addNewUser(uinfo) err = c.addNewUser(uinfo)
if err != nil { resp = check(err)
return c.respond([]byte(err.Error())), nil
} else {
return c.respond([]byte("ok")), nil
}
case 5: case 5:
err = c.delUser(plain[1:]) err = c.delUser(arg)
if err != nil { resp = check(err)
return c.respond([]byte(err.Error())), nil
} else {
return c.respond([]byte("ok")), nil
}
case 6: case 6:
err = c.syncMemFromDB(plain[1:33]) err = c.syncMemFromDB(arg)
if err != nil { resp = check(err)
return c.respond([]byte(err.Error())), nil case 7:
} else { if len(arg) < 36 {
return c.respond([]byte("ok")), nil resp = c.respond([]byte(errInvalidArgument.Error()))
break
} }
// TODO: implement the rest err = c.setSessionsCap(arg[0:32], Uint32(arg[32:36]))
resp = check(err)
case 8:
if len(arg) < 40 {
resp = c.respond([]byte(errInvalidArgument.Error()))
break
}
err = c.setUpRate(arg[0:32], int64(Uint64(arg[32:40])))
resp = check(err)
case 9:
if len(arg) < 40 {
resp = c.respond([]byte(errInvalidArgument.Error()))
break
}
err = c.setDownRate(arg[0:32], int64(Uint64(arg[32:40])))
resp = check(err)
case 10:
if len(arg) < 40 {
resp = c.respond([]byte(errInvalidArgument.Error()))
break
}
err = c.setUpCredit(arg[0:32], int64(Uint64(arg[32:40])))
resp = check(err)
case 11:
if len(arg) < 40 {
resp = c.respond([]byte(errInvalidArgument.Error()))
break
}
err = c.setDownCredit(arg[0:32], int64(Uint64(arg[32:40])))
resp = check(err)
case 12:
if len(arg) < 40 {
resp = c.respond([]byte(errInvalidArgument.Error()))
break
}
err = c.setExpiryTime(arg[0:32], int64(Uint64(arg[32:40])))
resp = check(err)
case 13:
if len(arg) < 40 {
resp = c.respond([]byte(errInvalidArgument.Error()))
break
}
err = c.addUpCredit(arg[0:32], int64(Uint64(arg[32:40])))
resp = check(err)
case 14:
if len(arg) < 40 {
resp = c.respond([]byte(errInvalidArgument.Error()))
break
}
err = c.addDownCredit(arg[0:32], int64(Uint64(arg[32:40])))
resp = check(err)
default: default:
return c.respond([]byte("Unsupported action")), nil return c.respond([]byte("Unsupported action")), nil
} }
return
} }

View File

@ -416,3 +416,50 @@ func (up *Userpanel) setExpiryTime(UID []byte, time int64) error {
u.setExpiryTime(time) u.setExpiryTime(time)
return nil return nil
} }
func (up *Userpanel) addUpCredit(UID []byte, delta int64) error {
err := up.db.Update(func(tx *bolt.Tx) error {
b := tx.Bucket(UID)
if b == nil {
return ErrUserNotFound
}
old := b.Get([]byte("UpCredit"))
new := int64(Uint64(old)) + delta
if err := b.Put([]byte("UpCredit"), i64ToB(new)); err != nil {
return err
}
return nil
})
if err != nil {
return err
}
u := up.getActiveUser(UID)
if u == nil {
return nil
}
u.addUpCredit(delta)
return nil
}
func (up *Userpanel) addDownCredit(UID []byte, delta int64) error {
err := up.db.Update(func(tx *bolt.Tx) error {
b := tx.Bucket(UID)
if b == nil {
return ErrUserNotFound
}
old := b.Get([]byte("DownCredit"))
new := int64(Uint64(old)) + delta
if err := b.Put([]byte("DownCredit"), i64ToB(new)); err != nil {
return err
}
return nil
})
if err != nil {
return err
}
u := up.getActiveUser(UID)
if u == nil {
return nil
}
u.addDownCredit(delta)
return nil
}