Rewrite user authentication, credit bookkeeping and db interaction

This commit is contained in:
Qian Wang 2019-07-22 13:42:39 +01:00
parent f66196d0c9
commit 29a45bcc1a
11 changed files with 407 additions and 856 deletions

View File

@ -1,7 +1,6 @@
package main
import (
"bytes"
"encoding/base64"
"flag"
"fmt"
@ -14,7 +13,6 @@ import (
mux "github.com/cbeuw/Cloak/internal/multiplex"
"github.com/cbeuw/Cloak/internal/server"
"github.com/cbeuw/Cloak/internal/server/usermanager"
"github.com/cbeuw/Cloak/internal/util"
)
@ -103,41 +101,38 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
return nil
}
// adminUID can use the server as normal with unlimited QoS credits. The adminUID is not
// added to the userinfo database. The distinction between going into the admin mode
// and normal proxy mode is that sessionID needs == 0 for admin mode
if bytes.Equal(UID, sta.AdminUID) && sessionID == 0 {
err = finishHandshake()
if err != nil {
log.Println(err)
return
}
c := sta.Userpanel.MakeController(sta.AdminUID)
for {
n, err := conn.Read(data)
/*
// adminUID can use the server as normal with unlimited QoS credits. The adminUID is not
// added to the userinfo database. The distinction between going into the admin mode
// and normal proxy mode is that sessionID needs == 0 for admin mode
if bytes.Equal(UID, sta.AdminUID) && sessionID == 0 {
err = finishHandshake()
if err != nil {
log.Println(err)
return
}
resp, err := c.HandleRequest(data[:n])
if err != nil {
log.Println(err)
}
_, err = conn.Write(resp)
if err != nil {
log.Println(err)
return
c := sta.Userpanel.MakeController(sta.AdminUID)
for {
n, err := conn.Read(data)
if err != nil {
log.Println(err)
return
}
resp, err := c.HandleRequest(data[:n])
if err != nil {
log.Println(err)
}
_, err = conn.Write(resp)
if err != nil {
log.Println(err)
return
}
}
}
*/
}
var user *usermanager.User
if bytes.Equal(UID, sta.AdminUID) {
user, err = sta.Userpanel.GetAndActivateAdminUser(UID)
} else {
user, err = sta.Userpanel.GetAndActivateUser(UID)
}
user, err := sta.Panel.GetUser(UID)
if err != nil {
log.Printf("+1 unauthorised user from %v, uid: %x\n", conn.RemoteAddr(), UID)
goWeb(data)
@ -278,10 +273,6 @@ func main() {
sta.ProxyBook["shadowsocks"] = ssLocalHost + ":" + ssLocalPort
}
if sta.AdminUID == nil {
log.Fatalln("AdminUID cannot be empty!")
}
go sta.UsedRandomCleaner()
listen := func(addr, port string) {

View File

@ -17,14 +17,14 @@ type Valve struct {
rxtb atomic.Value // *ratelimit.Bucket
txtb atomic.Value // *ratelimit.Bucket
rxCredit *int64
txCredit *int64
rx *int64
tx *int64
}
func MakeValve(rxRate, txRate int64, rxCredit, txCredit *int64) *Valve {
func MakeValve(rxRate, txRate int64, rx, tx *int64) *Valve {
v := &Valve{
rxCredit: rxCredit,
txCredit: txCredit,
rx: rx,
tx: tx,
}
v.SetRxRate(rxRate)
v.SetTxRate(txRate)
@ -35,13 +35,12 @@ func (v *Valve) SetRxRate(rate int64) { v.rxtb.Store(ratelimit.NewBucketWithRate
func (v *Valve) SetTxRate(rate int64) { v.txtb.Store(ratelimit.NewBucketWithRate(float64(rate), rate)) }
func (v *Valve) rxWait(n int) { v.rxtb.Load().(*ratelimit.Bucket).Wait(int64(n)) }
func (v *Valve) txWait(n int) { v.txtb.Load().(*ratelimit.Bucket).Wait(int64(n)) }
func (v *Valve) SetRxCredit(n int64) { atomic.StoreInt64(v.rxCredit, n) }
func (v *Valve) SetTxCredit(n int64) { atomic.StoreInt64(v.txCredit, n) }
func (v *Valve) GetRxCredit() int64 { return atomic.LoadInt64(v.rxCredit) }
func (v *Valve) GetTxCredit() int64 { return atomic.LoadInt64(v.txCredit) }
// n can be negative
func (v *Valve) AddRxCredit(n int64) int64 { return atomic.AddInt64(v.rxCredit, n) }
// n can be negative
func (v *Valve) AddTxCredit(n int64) int64 { return atomic.AddInt64(v.txCredit, n) }
func (v *Valve) AddRx(n int64) { atomic.AddInt64(v.rx, n) }
func (v *Valve) AddTx(n int64) { atomic.AddInt64(v.tx, n) }
func (v *Valve) GetRx() int64 { return atomic.LoadInt64(v.rx) }
func (v *Valve) GetTx() int64 { return atomic.LoadInt64(v.tx) }
func (v *Valve) Nullify() (int64, int64) {
rx := atomic.SwapInt64(v.rx, 0)
tx := atomic.SwapInt64(v.tx, 0)
return rx, tx
}

View File

@ -50,9 +50,6 @@ func makeSwitchboard(sesh *Session, valve *Valve) *switchboard {
var errNilOptimum error = errors.New("The optimal connection is nil")
var ErrNoRxCredit error = errors.New("No Rx credit is left")
var ErrNoTxCredit error = errors.New("No Tx credit is left")
func (sb *switchboard) send(data []byte) (int, error) {
ce := sb.getOptimum()
if ce == nil {
@ -65,11 +62,7 @@ func (sb *switchboard) send(data []byte) (int, error) {
return n, err
}
sb.txWait(n)
if sb.AddTxCredit(-int64(n)) < 0 {
log.Println(ErrNoTxCredit)
go sb.session.Close()
return n, ErrNoTxCredit
}
sb.Valve.AddTx(int64(n))
atomic.AddUint32(&ce.sendQueue, ^uint32(n-1))
go sb.updateOptimum()
return n, nil
@ -133,17 +126,14 @@ func (sb *switchboard) deplex(ce *connEnclave) {
for {
n, err := sb.session.obfsedRead(ce.remoteConn, buf)
sb.rxWait(n)
sb.Valve.AddRx(int64(n))
if err != nil {
//log.Println(err)
go ce.remoteConn.Close()
sb.removeConn(ce)
return
}
if sb.AddRxCredit(-int64(n)) < 0 {
log.Println(ErrNoRxCredit)
sb.session.Close()
return
}
frame, err := sb.session.deobfs(buf[:n])
if err != nil {
log.Println(err)

View File

@ -0,0 +1,67 @@
package server
import (
"net"
"sync"
mux "github.com/cbeuw/Cloak/internal/multiplex"
)
type ActiveUser struct {
panel *userPanel
arrUID [16]byte
valve *mux.Valve
sessionsM sync.RWMutex
sessions map[uint32]*mux.Session
}
func (u *ActiveUser) DelSession(sessionID uint32) {
u.sessionsM.Lock()
delete(u.sessions, sessionID)
if len(u.sessions) == 0 {
u.panel.updateUsageQueueForOne(u)
u.panel.activeUsersM.Lock()
delete(u.panel.activeUsers, u.arrUID)
u.panel.activeUsersM.Unlock()
}
u.sessionsM.Unlock()
}
func (u *ActiveUser) GetSession(sessionID uint32, obfs mux.Obfser, deobfs mux.Deobfser, obfsedRead func(net.Conn, []byte) (int, error)) (sesh *mux.Session, existing bool, err error) {
u.sessionsM.Lock()
if sesh = u.sessions[sessionID]; sesh != nil {
u.sessionsM.Unlock()
return sesh, true, nil
} else {
err := u.panel.manager.authoriseNewSession(u)
if err != nil {
u.sessionsM.Unlock()
return nil, false, err
}
sesh = mux.MakeSession(sessionID, u.valve, obfs, deobfs, obfsedRead)
u.sessions[sessionID] = sesh
u.sessionsM.Unlock()
return sesh, false, nil
}
}
func (u *ActiveUser) Terminate() {
u.sessionsM.Lock()
for _, sesh := range u.sessions {
go sesh.Close()
}
u.sessionsM.Unlock()
u.panel.activeUsersM.Lock()
delete(u.panel.activeUsers, u.arrUID)
u.panel.activeUsersM.Unlock()
}
func (u *ActiveUser) NumSession() int {
u.sessionsM.RLock()
l := len(u.sessions)
u.sessionsM.RUnlock()
return l
}

View File

@ -8,8 +8,6 @@ import (
"io/ioutil"
"sync"
"time"
"github.com/cbeuw/Cloak/internal/server/usermanager"
)
type rawConfig struct {
@ -19,6 +17,7 @@ type rawConfig struct {
AdminUID string
DatabasePath string
BackupDirPath string
CncMode bool
}
// State type stores the global state of the program
@ -28,14 +27,16 @@ type State struct {
BindHost string
BindPort string
Now func() time.Time
AdminUID []byte
staticPv crypto.PrivateKey
Userpanel *usermanager.Userpanel
Now func() time.Time
AdminUID []byte
staticPv crypto.PrivateKey
RedirAddr string
usedRandomM sync.RWMutex
usedRandom map[[32]byte]int
RedirAddr string
Panel *userPanel
}
func InitState(bindHost, bindPort string, nowFunc func() time.Time) (*State, error) {
@ -45,6 +46,7 @@ func InitState(bindHost, bindPort string, nowFunc func() time.Time) (*State, err
Now: nowFunc,
}
ret.usedRandom = make(map[[32]byte]int)
go ret.UsedRandomCleaner()
return ret, nil
}
@ -66,11 +68,15 @@ func (sta *State) ParseConfig(conf string) (err error) {
}
}
up, err := usermanager.MakeUserpanel(preParse.DatabasePath, preParse.BackupDirPath)
if err != nil {
return errors.New("Attempting to open database: " + err.Error())
if preParse.CncMode {
} else {
manager, err := MakeLocalManager(preParse.DatabasePath)
if err != nil {
return err
}
sta.Panel = MakeUserPanel(manager)
}
sta.Userpanel = up
sta.RedirAddr = preParse.RedirAddr
sta.ProxyBook = preParse.ProxyBook

115
internal/server/um_local.go Normal file
View File

@ -0,0 +1,115 @@
package server
import (
"encoding/binary"
"log"
"time"
"github.com/boltdb/bolt"
)
var Uint32 = binary.BigEndian.Uint32
var Uint64 = binary.BigEndian.Uint64
var PutUint32 = binary.BigEndian.PutUint32
var PutUint64 = binary.BigEndian.PutUint64
type localManager struct {
db *bolt.DB
}
func MakeLocalManager(dbPath string) (*localManager, error) {
db, err := bolt.Open(dbPath, 0600, nil)
if err != nil {
return nil, err
}
return &localManager{db}, nil
}
func (manager *localManager) authenticateUser(UID []byte) (int64, int64, error) {
var upRate, downRate, upCredit, downCredit, expiryTime int64
err := manager.db.View(func(tx *bolt.Tx) error {
bucket := tx.Bucket(UID)
if bucket == nil {
return ErrUserNotFound
}
upRate = int64(Uint64(bucket.Get([]byte("UpRate"))))
downRate = int64(Uint64(bucket.Get([]byte("DownRate"))))
upCredit = int64(Uint64(bucket.Get([]byte("UpCredit"))))
downCredit = int64(Uint64(bucket.Get([]byte("DownCredit"))))
expiryTime = int64(Uint64(bucket.Get([]byte("ExpiryTime"))))
return nil
})
if err != nil {
return 0, 0, err
}
if upCredit <= 0 {
return 0, 0, ErrNoUpCredit
}
if downCredit <= 0 {
return 0, 0, ErrNoDownCredit
}
if expiryTime < time.Now().Unix() {
return 0, 0, ErrUserExpired
}
return upRate, downRate, nil
}
func (manager *localManager) authoriseNewSession(user *ActiveUser) error {
var sessionsCap int
var upCredit, downCredit, expiryTime int64
err := manager.db.View(func(tx *bolt.Tx) error {
bucket := tx.Bucket(user.arrUID[:])
if bucket == nil {
return ErrUserNotFound
}
sessionsCap = int(Uint32(bucket.Get([]byte("SessionsCap"))))
upCredit = int64(Uint64(bucket.Get([]byte("UpCredit"))))
downCredit = int64(Uint64(bucket.Get([]byte("DownCredit"))))
expiryTime = int64(Uint64(bucket.Get([]byte("ExpiryTime"))))
return nil
})
if err != nil {
return err
}
if upCredit <= 0 {
return ErrNoUpCredit
}
if downCredit <= 0 {
return ErrNoDownCredit
}
if expiryTime < time.Now().Unix() {
return ErrUserExpired
}
//user.sessionsM.RLock()
if len(user.sessions) >= sessionsCap {
//user.sessionsM.RUnlock()
return ErrSessionsCapReached
}
//user.sessionsM.RUnlock()
return nil
}
func i64ToB(value int64) []byte {
oct := make([]byte, 8)
PutUint64(oct, uint64(value))
return oct
}
func (manager *localManager) uploadStatus(uploads []statusUpdate) error {
err := manager.db.Update(func(tx *bolt.Tx) error {
for _, status := range uploads {
bucket := tx.Bucket(status.UID)
if bucket == nil {
log.Printf("%x doesn't exist\n", status.UID)
continue
}
oldUp := int64(Uint64(bucket.Get([]byte("UpCredit"))))
bucket.Put([]byte("UpCredit"), i64ToB(oldUp-status.upUsage))
oldDown := int64(Uint64(bucket.Get([]byte("DownCredit"))))
bucket.Put([]byte("DownCredit"), i64ToB(oldDown-status.downUsage))
}
return nil
})
return err
}

View File

@ -0,0 +1,28 @@
package server
import (
"errors"
)
type statusUpdate struct {
UID []byte
active bool
numSession int
upUsage int64
downUsage int64
timestamp int64
}
var ErrUserNotFound = errors.New("UID does not correspond to a user")
var ErrSessionsCapReached = errors.New("Sessions cap has reached")
var ErrNoUpCredit = errors.New("No upload credit left")
var ErrNoDownCredit = errors.New("No download credit left")
var ErrUserExpired = errors.New("User has expired")
type UserManager interface {
authenticateUser([]byte) (int64, int64, error)
authoriseNewSession(*ActiveUser) error
// TODO: fetch update's response
uploadStatus([]statusUpdate) error
}

View File

@ -1,212 +0,0 @@
package usermanager
import (
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/json"
"errors"
"log"
)
// FIXME: sanity checks. The server may panic due to user input
// TODO: manual backup
/*
0 reserved
1 listActiveUsers none []uids
2 listAllUsers none []userinfo
3 getUserInfo uid userinfo
4 addNewUser userinfo ok
5 delUser uid ok
6 syncMemFromDB uid ok
7 setSessionsCap uid cap ok
8 setUpRate uid rate ok
9 setDownRate uid rate ok
10 setUpCredit uid credit ok
11 setDownCredit uid credit ok
12 setExpiryTime uid time ok
13 addUpCredit uid delta ok
14 addDownCredit uid delta ok
*/
type controller struct {
*Userpanel
adminUID []byte
}
func (up *Userpanel) MakeController(adminUID []byte) *controller {
return &controller{up, adminUID}
}
var errInvalidArgument = errors.New("Invalid argument format")
func (c *controller) HandleRequest(req []byte) (resp []byte, err error) {
check := func(err error) []byte {
if err != nil {
return c.respond([]byte(err.Error()))
} else {
return c.respond([]byte("ok"))
}
}
plain, err := c.checkAndDecrypt(req)
if err == ErrInvalidMac {
log.Printf("!!!CONTROL MESSAGE AND HMAC MISMATCH!!!\naUID:%x\nraw request:\n%x\ndecrypted msg:\n%x", c.adminUID, req, plain)
return nil, err
} else if err != nil {
log.Printf("aUID:%x\n,err:%v\n", c.adminUID, err)
return c.respond([]byte(err.Error())), nil
}
typ := plain[0]
var arg []byte
if len(plain) > 1 {
arg = plain[1:]
}
switch typ {
case 1:
UIDs := c.listActiveUsers()
resp, _ = json.Marshal(UIDs)
resp = c.respond(resp)
case 2:
uinfos := c.listAllUsers()
resp, _ = json.Marshal(uinfos)
resp = c.respond(resp)
case 3:
uinfo, err := c.getUserInfo(arg)
if err != nil {
resp = c.respond([]byte(err.Error()))
break
}
resp, _ = json.Marshal(uinfo)
resp = c.respond(resp)
case 4:
var uinfo UserInfo
err = json.Unmarshal(arg, &uinfo)
if err != nil {
resp = c.respond([]byte(err.Error()))
break
}
err = c.addNewUser(uinfo)
resp = check(err)
case 5:
err = c.delUser(arg)
resp = check(err)
case 6:
err = c.syncMemFromDB(arg)
resp = check(err)
case 7:
if len(arg) < 20 {
resp = c.respond([]byte(errInvalidArgument.Error()))
break
}
err = c.setSessionsCap(arg[0:16], Uint32(arg[16:20]))
resp = check(err)
case 8:
if len(arg) < 24 {
resp = c.respond([]byte(errInvalidArgument.Error()))
break
}
err = c.setUpRate(arg[0:16], int64(Uint64(arg[16:24])))
resp = check(err)
case 9:
if len(arg) < 24 {
resp = c.respond([]byte(errInvalidArgument.Error()))
break
}
err = c.setDownRate(arg[0:16], int64(Uint64(arg[16:24])))
resp = check(err)
case 10:
if len(arg) < 24 {
resp = c.respond([]byte(errInvalidArgument.Error()))
break
}
err = c.setUpCredit(arg[0:16], int64(Uint64(arg[16:24])))
resp = check(err)
case 11:
if len(arg) < 24 {
resp = c.respond([]byte(errInvalidArgument.Error()))
break
}
err = c.setDownCredit(arg[0:16], int64(Uint64(arg[16:24])))
resp = check(err)
case 12:
if len(arg) < 24 {
resp = c.respond([]byte(errInvalidArgument.Error()))
break
}
err = c.setExpiryTime(arg[0:16], int64(Uint64(arg[16:24])))
resp = check(err)
case 13:
if len(arg) < 24 {
resp = c.respond([]byte(errInvalidArgument.Error()))
break
}
err = c.addUpCredit(arg[0:16], int64(Uint64(arg[16:24])))
resp = check(err)
case 14:
if len(arg) < 24 {
resp = c.respond([]byte(errInvalidArgument.Error()))
break
}
err = c.addDownCredit(arg[0:16], int64(Uint64(arg[16:24])))
resp = check(err)
default:
return c.respond([]byte("Unsupported action")), nil
}
return
}
var ErrInvalidMac = errors.New("Mac mismatch")
var errMsgTooShort = errors.New("Message length is less than 54")
// protocol: [TLS record layer 5 bytes][IV 16 bytes][data][hmac 32 bytes]
func (c *controller) respond(resp []byte) []byte {
respLen := len(resp)
buf := make([]byte, 5+16+respLen+32)
buf[0] = 0x17
buf[1] = 0x03
buf[2] = 0x03
PutUint16(buf[3:5], uint16(16+respLen+32))
rand.Read(buf[5:21]) //iv
copy(buf[21:], resp)
block, _ := aes.NewCipher(c.adminUID)
stream := cipher.NewCTR(block, buf[5:21])
stream.XORKeyStream(buf[21:21+respLen], buf[21:21+respLen])
mac := hmac.New(sha256.New, c.adminUID)
mac.Write(buf[5 : 21+respLen])
copy(buf[21+respLen:], mac.Sum(nil))
return buf
}
func (c *controller) checkAndDecrypt(data []byte) ([]byte, error) {
if len(data) < 54 {
return nil, errMsgTooShort
}
macIndex := len(data) - 32
mac := hmac.New(sha256.New, c.adminUID)
mac.Write(data[5:macIndex])
expected := mac.Sum(nil)
if !hmac.Equal(data[macIndex:], expected) {
return nil, ErrInvalidMac
}
iv := data[5:21]
ret := data[21:macIndex]
block, _ := aes.NewCipher(c.adminUID)
stream := cipher.NewCTR(block, iv)
stream.XORKeyStream(ret, ret)
return ret, nil
}

View File

@ -1,98 +0,0 @@
package usermanager
import (
"errors"
"net"
"sync"
"sync/atomic"
"time"
mux "github.com/cbeuw/Cloak/internal/multiplex"
)
// for the ease of using json package
type UserInfo struct {
UID []byte
// ALL of the following fields have to be accessed atomically
SessionsCap uint32
UpRate int64
DownRate int64
UpCredit int64
DownCredit int64
ExpiryTime int64
}
type User struct {
up *Userpanel
arrUID [16]byte
*UserInfo
valve *mux.Valve
sessionsM sync.RWMutex
sessions map[uint32]*mux.Session
}
func MakeUser(up *Userpanel, uinfo *UserInfo) *User {
// this instance of valve is shared across ALL sessions of a user
valve := mux.MakeValve(uinfo.UpRate, uinfo.DownRate, &uinfo.UpCredit, &uinfo.DownCredit)
u := &User{
up: up,
UserInfo: uinfo,
valve: valve,
sessions: make(map[uint32]*mux.Session),
}
copy(u.arrUID[:], uinfo.UID)
return u
}
func (u *User) addUpCredit(delta int64) { u.valve.AddRxCredit(delta) }
func (u *User) addDownCredit(delta int64) { u.valve.AddTxCredit(delta) }
func (u *User) setSessionsCap(cap uint32) { atomic.StoreUint32(&u.SessionsCap, cap) }
func (u *User) setUpRate(rate int64) { u.valve.SetRxRate(rate) }
func (u *User) setDownRate(rate int64) { u.valve.SetTxRate(rate) }
func (u *User) setUpCredit(n int64) { u.valve.SetRxCredit(n) }
func (u *User) setDownCredit(n int64) { u.valve.SetTxCredit(n) }
func (u *User) setExpiryTime(time int64) { atomic.StoreInt64(&u.ExpiryTime, time) }
func (u *User) updateInfo(uinfo UserInfo) {
u.setSessionsCap(uinfo.SessionsCap)
u.setUpCredit(uinfo.UpCredit)
u.setDownCredit(uinfo.DownCredit)
u.setUpRate(uinfo.UpRate)
u.setDownRate(uinfo.DownRate)
u.setExpiryTime(uinfo.ExpiryTime)
}
func (u *User) DelSession(sessionID uint32) {
u.sessionsM.Lock()
delete(u.sessions, sessionID)
if len(u.sessions) == 0 {
u.sessionsM.Unlock()
u.up.delActiveUser(u.UID)
return
}
u.sessionsM.Unlock()
}
func (u *User) GetSession(sessionID uint32, obfs mux.Obfser, deobfs mux.Deobfser, obfsedRead func(net.Conn, []byte) (int, error)) (sesh *mux.Session, existing bool, err error) {
if time.Now().Unix() > u.ExpiryTime {
return nil, false, errors.New("Expiry time passed")
}
u.sessionsM.Lock()
if sesh = u.sessions[sessionID]; sesh != nil {
u.sessionsM.Unlock()
return sesh, true, nil
} else {
if len(u.sessions) >= int(u.SessionsCap) {
u.sessionsM.Unlock()
return nil, false, errors.New("SessionsCap reached")
}
sesh = mux.MakeSession(sessionID, u.valve, obfs, deobfs, obfsedRead)
u.sessions[sessionID] = sesh
u.sessionsM.Unlock()
return sesh, false, nil
}
}

View File

@ -1,473 +0,0 @@
package usermanager
import (
"encoding/base64"
"encoding/binary"
"errors"
"os"
"path"
"strconv"
"sync"
"time"
"github.com/boltdb/bolt"
)
var Uint32 = binary.BigEndian.Uint32
var Uint64 = binary.BigEndian.Uint64
var PutUint16 = binary.BigEndian.PutUint16
var PutUint32 = binary.BigEndian.PutUint32
var PutUint64 = binary.BigEndian.PutUint64
type Userpanel struct {
db *bolt.DB
bakRoot string
activeUsersM sync.RWMutex
activeUsers map[[16]byte]*User
}
func MakeUserpanel(dbPath, bakRoot string) (*Userpanel, error) {
db, err := bolt.Open(dbPath, 0600, nil)
if err != nil {
return nil, err
}
if bakRoot == "" {
os.Mkdir("db-backup", 0777)
bakRoot = "db-backup"
}
bakRoot = path.Clean(bakRoot)
up := &Userpanel{
db: db,
bakRoot: bakRoot,
activeUsers: make(map[[16]byte]*User),
}
go func() {
for {
time.Sleep(time.Second * 10)
up.updateCredits()
}
}()
return up, nil
}
// credits of all users are updated together so that there is only 1 goroutine managing it
func (up *Userpanel) updateCredits() {
up.activeUsersM.RLock()
for _, u := range up.activeUsers {
up.db.Update(func(tx *bolt.Tx) error {
b := tx.Bucket(u.arrUID[:])
if b == nil {
return ErrUserNotFound
}
if err := b.Put([]byte("UpCredit"), i64ToB(u.valve.GetRxCredit())); err != nil {
return err
}
if err := b.Put([]byte("DownCredit"), i64ToB(u.valve.GetTxCredit())); err != nil {
return err
}
return nil
})
}
up.activeUsersM.RUnlock()
}
func (up *Userpanel) backupDB(bakFileName string) error {
bakPath := up.bakRoot + "/" + bakFileName
_, err := os.Stat(bakPath)
if err == nil {
return errors.New("Attempting to overwrite a file during backup!")
}
var bak *os.File
if os.IsNotExist(err) {
bak, err = os.Create(bakPath)
if err != nil {
return err
}
}
err = up.db.View(func(tx *bolt.Tx) error {
_, err := tx.WriteTo(bak)
if err != nil {
return err
}
return nil
})
return err
}
var ErrUserNotFound = errors.New("User does not exist in db")
var ErrUserNotActive = errors.New("User is not active")
func (up *Userpanel) GetAndActivateAdminUser(AdminUID []byte) (*User, error) {
up.activeUsersM.Lock()
var arrUID [16]byte
copy(arrUID[:], AdminUID)
if user, ok := up.activeUsers[arrUID]; ok {
up.activeUsersM.Unlock()
return user, nil
}
uinfo := UserInfo{
UID: AdminUID,
SessionsCap: 1e9,
UpRate: 1e12,
DownRate: 1e12,
UpCredit: 1e15,
DownCredit: 1e15,
ExpiryTime: 1e15,
}
user := MakeUser(up, &uinfo)
up.activeUsers[arrUID] = user
up.activeUsersM.Unlock()
return user, nil
}
// GetUser is used to retrieve a user if s/he is active, or to retrieve the user's info
// from the db and mark it as an active user
func (up *Userpanel) GetAndActivateUser(UID []byte) (*User, error) {
up.activeUsersM.Lock()
var arrUID [16]byte
copy(arrUID[:], UID)
if user, ok := up.activeUsers[arrUID]; ok {
up.activeUsersM.Unlock()
return user, nil
}
var uinfo UserInfo
uinfo.UID = UID
err := up.db.View(func(tx *bolt.Tx) error {
b := tx.Bucket(UID[:])
if b == nil {
return ErrUserNotFound
}
uinfo.SessionsCap = Uint32(b.Get([]byte("SessionsCap")))
uinfo.UpRate = int64(Uint64(b.Get([]byte("UpRate"))))
uinfo.DownRate = int64(Uint64(b.Get([]byte("DownRate"))))
uinfo.UpCredit = int64(Uint64(b.Get([]byte("UpCredit")))) // reee brackets
uinfo.DownCredit = int64(Uint64(b.Get([]byte("DownCredit"))))
uinfo.ExpiryTime = int64(Uint64(b.Get([]byte("ExpiryTime"))))
return nil
})
if err != nil {
up.activeUsersM.Unlock()
return nil, err
}
u := MakeUser(up, &uinfo)
up.activeUsers[arrUID] = u
up.activeUsersM.Unlock()
return u, nil
}
func (up *Userpanel) updateDBEntryUint32(UID []byte, key string, value uint32) error {
err := up.db.Update(func(tx *bolt.Tx) error {
b := tx.Bucket(UID)
if b == nil {
return ErrUserNotFound
}
if err := b.Put([]byte(key), u32ToB(value)); err != nil {
return err
}
return nil
})
return err
}
func (up *Userpanel) updateDBEntryInt64(UID []byte, key string, value int64) error {
err := up.db.Update(func(tx *bolt.Tx) error {
b := tx.Bucket(UID)
if b == nil {
return ErrUserNotFound
}
if err := b.Put([]byte(key), i64ToB(value)); err != nil {
return err
}
return nil
})
return err
}
// This is used when all sessions of a user close
func (up *Userpanel) delActiveUser(UID []byte) {
var arrUID [16]byte
copy(arrUID[:], UID)
up.activeUsersM.Lock()
delete(up.activeUsers, arrUID)
up.activeUsersM.Unlock()
}
func (up *Userpanel) getActiveUser(UID []byte) *User {
var arrUID [16]byte
copy(arrUID[:], UID)
up.activeUsersM.RLock()
ret := up.activeUsers[arrUID]
up.activeUsersM.RUnlock()
return ret
}
// below are remote control utilised functions
func (up *Userpanel) listActiveUsers() [][]byte {
var ret [][]byte
up.activeUsersM.RLock()
for _, u := range up.activeUsers {
ret = append(ret, u.UID)
}
up.activeUsersM.RUnlock()
return ret
}
func (up *Userpanel) listAllUsers() []UserInfo {
var ret []UserInfo
up.db.View(func(tx *bolt.Tx) error {
tx.ForEach(func(UID []byte, b *bolt.Bucket) error {
// if we want to avoid writing every single key out,
// we would have to either make UserInfo a map,
// or use reflect.
// neither is convinient
var uinfo UserInfo
uinfo.UID = UID
uinfo.SessionsCap = Uint32(b.Get([]byte("SessionsCap")))
uinfo.UpRate = int64(Uint64(b.Get([]byte("UpRate"))))
uinfo.DownRate = int64(Uint64(b.Get([]byte("DownRate"))))
uinfo.UpCredit = int64(Uint64(b.Get([]byte("UpCredit"))))
uinfo.DownCredit = int64(Uint64(b.Get([]byte("DownCredit"))))
uinfo.ExpiryTime = int64(Uint64(b.Get([]byte("ExpiryTime"))))
ret = append(ret, uinfo)
return nil
})
return nil
})
return ret
}
func (up *Userpanel) getUserInfo(UID []byte) (UserInfo, error) {
var uinfo UserInfo
err := up.db.View(func(tx *bolt.Tx) error {
b := tx.Bucket(UID)
if b == nil {
return ErrUserNotFound
}
uinfo.UID = UID
uinfo.SessionsCap = Uint32(b.Get([]byte("SessionsCap")))
uinfo.UpRate = int64(Uint64(b.Get([]byte("UpRate"))))
uinfo.DownRate = int64(Uint64(b.Get([]byte("DownRate"))))
uinfo.UpCredit = int64(Uint64(b.Get([]byte("UpCredit"))))
uinfo.DownCredit = int64(Uint64(b.Get([]byte("DownCredit"))))
uinfo.ExpiryTime = int64(Uint64(b.Get([]byte("ExpiryTime"))))
return nil
})
return uinfo, err
}
// In boltdb, the value argument for bucket.Put has to be valid for the duration
// of the transaction.
// This basically means that you cannot reuse a byte slice for two different keys
// in a transaction. So we need to allocate a fresh byte slice for each value
func u32ToB(value uint32) []byte {
quad := make([]byte, 4)
PutUint32(quad, value)
return quad
}
func i64ToB(value int64) []byte {
oct := make([]byte, 8)
PutUint64(oct, uint64(value))
return oct
}
func (up *Userpanel) addNewUser(uinfo UserInfo) error {
err := up.db.Update(func(tx *bolt.Tx) error {
b, err := tx.CreateBucket(uinfo.UID[:])
if err != nil {
return err
}
if err = b.Put([]byte("SessionsCap"), u32ToB(uinfo.SessionsCap)); err != nil {
return err
}
if err = b.Put([]byte("UpRate"), i64ToB(uinfo.UpRate)); err != nil {
return err
}
if err = b.Put([]byte("DownRate"), i64ToB(uinfo.DownRate)); err != nil {
return err
}
if err = b.Put([]byte("UpCredit"), i64ToB(uinfo.UpCredit)); err != nil {
return err
}
if err = b.Put([]byte("DownCredit"), i64ToB(uinfo.DownCredit)); err != nil {
return err
}
if err = b.Put([]byte("ExpiryTime"), i64ToB(uinfo.ExpiryTime)); err != nil {
return err
}
return nil
})
return err
}
func (up *Userpanel) delUser(UID []byte) error {
err := up.backupDB(strconv.FormatInt(time.Now().Unix(), 10) + "_pre_del_" + base64.StdEncoding.EncodeToString(UID) + ".bak")
if err != nil {
return err
}
err = up.db.Update(func(tx *bolt.Tx) error {
return tx.DeleteBucket(UID)
})
return err
}
func (up *Userpanel) syncMemFromDB(UID []byte) error {
var uinfo UserInfo
err := up.db.View(func(tx *bolt.Tx) error {
b := tx.Bucket(UID)
if b == nil {
return ErrUserNotFound
}
uinfo.UID = UID
uinfo.SessionsCap = Uint32(b.Get([]byte("SessionsCap")))
uinfo.UpRate = int64(Uint64(b.Get([]byte("UpRate"))))
uinfo.DownRate = int64(Uint64(b.Get([]byte("DownRate"))))
uinfo.UpCredit = int64(Uint64(b.Get([]byte("UpCredit"))))
uinfo.DownCredit = int64(Uint64(b.Get([]byte("DownCredit"))))
uinfo.ExpiryTime = int64(Uint64(b.Get([]byte("ExpiryTime"))))
return nil
})
if err != nil {
return err
}
u := up.getActiveUser(UID)
if u == nil {
return ErrUserNotActive
}
u.updateInfo(uinfo)
return nil
}
// the following functions will update the db entries first, then if the
// user is active, it will update it in memory.
func (up *Userpanel) setSessionsCap(UID []byte, cap uint32) error {
err := up.updateDBEntryUint32(UID, "SessionsCap", cap)
if err != nil {
return err
}
u := up.getActiveUser(UID)
if u == nil {
return nil
}
u.setSessionsCap(cap)
return nil
}
func (up *Userpanel) setUpRate(UID []byte, rate int64) error {
err := up.updateDBEntryInt64(UID, "UpRate", rate)
if err != nil {
return err
}
u := up.getActiveUser(UID)
if u == nil {
return nil
}
u.setUpRate(rate)
return nil
}
func (up *Userpanel) setDownRate(UID []byte, rate int64) error {
err := up.updateDBEntryInt64(UID, "DownRate", rate)
if err != nil {
return err
}
u := up.getActiveUser(UID)
if u == nil {
return nil
}
u.setDownRate(rate)
return nil
}
func (up *Userpanel) setUpCredit(UID []byte, n int64) error {
err := up.updateDBEntryInt64(UID, "UpCredit", n)
if err != nil {
return err
}
u := up.getActiveUser(UID)
if u == nil {
return nil
}
u.setUpCredit(n)
return nil
}
func (up *Userpanel) setDownCredit(UID []byte, n int64) error {
err := up.updateDBEntryInt64(UID, "DownCredit", n)
if err != nil {
return err
}
u := up.getActiveUser(UID)
if u == nil {
return nil
}
u.setDownCredit(n)
return nil
}
func (up *Userpanel) setExpiryTime(UID []byte, time int64) error {
err := up.updateDBEntryInt64(UID, "ExpiryTime", time)
if err != nil {
return err
}
u := up.getActiveUser(UID)
if u == nil {
return nil
}
u.setExpiryTime(time)
return nil
}
func (up *Userpanel) addUpCredit(UID []byte, delta int64) error {
err := up.db.Update(func(tx *bolt.Tx) error {
b := tx.Bucket(UID)
if b == nil {
return ErrUserNotFound
}
old := b.Get([]byte("UpCredit"))
new := int64(Uint64(old)) + delta
if err := b.Put([]byte("UpCredit"), i64ToB(new)); err != nil {
return err
}
return nil
})
if err != nil {
return err
}
u := up.getActiveUser(UID)
if u == nil {
return nil
}
u.addUpCredit(delta)
return nil
}
func (up *Userpanel) addDownCredit(UID []byte, delta int64) error {
err := up.db.Update(func(tx *bolt.Tx) error {
b := tx.Bucket(UID)
if b == nil {
return ErrUserNotFound
}
old := b.Get([]byte("DownCredit"))
new := int64(Uint64(old)) + delta
if err := b.Put([]byte("DownCredit"), i64ToB(new)); err != nil {
return err
}
return nil
})
if err != nil {
return err
}
u := up.getActiveUser(UID)
if u == nil {
return nil
}
u.addDownCredit(delta)
return nil
}

View File

@ -0,0 +1,138 @@
package server
import (
"sync"
"sync/atomic"
"time"
mux "github.com/cbeuw/Cloak/internal/multiplex"
)
type userPanel struct {
manager UserManager
activeUsersM sync.RWMutex
activeUsers map[[16]byte]*ActiveUser
usageUpdateQueueM sync.Mutex
usageUpdateQueue map[[16]byte]*usagePair
}
func MakeUserPanel(manager UserManager) *userPanel {
ret := &userPanel{
manager: manager,
activeUsers: make(map[[16]byte]*ActiveUser),
usageUpdateQueue: make(map[[16]byte]*usagePair),
}
go ret.regularQueueUpload()
return ret
}
func (panel *userPanel) GetUser(UID []byte) (*ActiveUser, error) {
panel.activeUsersM.Lock()
var arrUID [16]byte
copy(arrUID[:], UID)
if user, ok := panel.activeUsers[arrUID]; ok {
panel.activeUsersM.Unlock()
return user, nil
}
upRate, downRate, err := panel.manager.authenticateUser(UID)
if err != nil {
panel.activeUsersM.Unlock()
return nil, err
}
var upUsage, downUsage int64
valve := mux.MakeValve(upRate, downRate, &upUsage, &downUsage)
user := &ActiveUser{
panel: panel,
valve: valve,
sessions: make(map[uint32]*mux.Session),
}
copy(user.arrUID[:], UID)
panel.activeUsers[user.arrUID] = user
panel.activeUsersM.Unlock()
return user, nil
}
func (panel *userPanel) isActive(UID []byte) bool {
var arrUID [16]byte
copy(arrUID[:], UID)
panel.activeUsersM.RLock()
_, ok := panel.activeUsers[arrUID]
panel.activeUsersM.RUnlock()
return ok
}
type usagePair struct {
up *int64
down *int64
}
func (panel *userPanel) updateUsageQueue() {
panel.activeUsersM.Lock()
panel.usageUpdateQueueM.Lock()
for _, user := range panel.activeUsers {
upIncured, downIncured := user.valve.Nullify()
if usage, ok := panel.usageUpdateQueue[user.arrUID]; ok {
atomic.AddInt64(usage.up, upIncured)
atomic.AddInt64(usage.down, downIncured)
} else {
// if the user hasn't been added to the queue
usage = &usagePair{&upIncured, &downIncured}
panel.usageUpdateQueue[user.arrUID] = usage
}
}
panel.activeUsersM.Unlock()
panel.usageUpdateQueueM.Unlock()
}
func (panel *userPanel) updateUsageQueueForOne(user *ActiveUser) {
// used when one particular user deactivates
upIncured, downIncured := user.valve.Nullify()
panel.usageUpdateQueueM.Lock()
if usage, ok := panel.usageUpdateQueue[user.arrUID]; ok {
atomic.AddInt64(usage.up, upIncured)
atomic.AddInt64(usage.down, downIncured)
} else {
usage = &usagePair{&upIncured, &downIncured}
panel.usageUpdateQueue[user.arrUID] = usage
}
panel.usageUpdateQueueM.Unlock()
}
func (panel *userPanel) commitUpdate() {
panel.usageUpdateQueueM.Lock()
statuses := make([]statusUpdate, 0, len(panel.usageUpdateQueue))
for arrUID, usage := range panel.usageUpdateQueue {
panel.activeUsersM.RLock()
user := panel.activeUsers[arrUID]
panel.activeUsersM.RUnlock()
var numSession int
if user != nil {
numSession = user.NumSession()
}
status := statusUpdate{
UID: arrUID[:],
active: panel.isActive(arrUID[:]),
numSession: numSession,
upUsage: *usage.up,
downUsage: *usage.down,
timestamp: time.Now().Unix(),
}
statuses = append(statuses, status)
}
panel.manager.uploadStatus(statuses)
panel.usageUpdateQueue = make(map[[16]byte]*usagePair)
panel.usageUpdateQueueM.Unlock()
}
func (panel *userPanel) regularQueueUpload() {
for {
time.Sleep(1 * time.Minute)
go func() {
panel.updateUsageQueue()
panel.commitUpdate()
}()
}
}