mirror of https://github.com/cbeuw/Cloak
Rewrite user authentication, credit bookkeeping and db interaction
This commit is contained in:
parent
f66196d0c9
commit
29a45bcc1a
|
|
@ -1,7 +1,6 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"flag"
|
||||
"fmt"
|
||||
|
|
@ -14,7 +13,6 @@ import (
|
|||
|
||||
mux "github.com/cbeuw/Cloak/internal/multiplex"
|
||||
"github.com/cbeuw/Cloak/internal/server"
|
||||
"github.com/cbeuw/Cloak/internal/server/usermanager"
|
||||
"github.com/cbeuw/Cloak/internal/util"
|
||||
)
|
||||
|
||||
|
|
@ -103,41 +101,38 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
|
|||
return nil
|
||||
}
|
||||
|
||||
// adminUID can use the server as normal with unlimited QoS credits. The adminUID is not
|
||||
// added to the userinfo database. The distinction between going into the admin mode
|
||||
// and normal proxy mode is that sessionID needs == 0 for admin mode
|
||||
if bytes.Equal(UID, sta.AdminUID) && sessionID == 0 {
|
||||
err = finishHandshake()
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
c := sta.Userpanel.MakeController(sta.AdminUID)
|
||||
for {
|
||||
n, err := conn.Read(data)
|
||||
/*
|
||||
// adminUID can use the server as normal with unlimited QoS credits. The adminUID is not
|
||||
// added to the userinfo database. The distinction between going into the admin mode
|
||||
// and normal proxy mode is that sessionID needs == 0 for admin mode
|
||||
if bytes.Equal(UID, sta.AdminUID) && sessionID == 0 {
|
||||
err = finishHandshake()
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
resp, err := c.HandleRequest(data[:n])
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
_, err = conn.Write(resp)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return
|
||||
c := sta.Userpanel.MakeController(sta.AdminUID)
|
||||
for {
|
||||
n, err := conn.Read(data)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
resp, err := c.HandleRequest(data[:n])
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
_, err = conn.Write(resp)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
*/
|
||||
|
||||
}
|
||||
|
||||
var user *usermanager.User
|
||||
if bytes.Equal(UID, sta.AdminUID) {
|
||||
user, err = sta.Userpanel.GetAndActivateAdminUser(UID)
|
||||
} else {
|
||||
user, err = sta.Userpanel.GetAndActivateUser(UID)
|
||||
}
|
||||
user, err := sta.Panel.GetUser(UID)
|
||||
if err != nil {
|
||||
log.Printf("+1 unauthorised user from %v, uid: %x\n", conn.RemoteAddr(), UID)
|
||||
goWeb(data)
|
||||
|
|
@ -278,10 +273,6 @@ func main() {
|
|||
sta.ProxyBook["shadowsocks"] = ssLocalHost + ":" + ssLocalPort
|
||||
}
|
||||
|
||||
if sta.AdminUID == nil {
|
||||
log.Fatalln("AdminUID cannot be empty!")
|
||||
}
|
||||
|
||||
go sta.UsedRandomCleaner()
|
||||
|
||||
listen := func(addr, port string) {
|
||||
|
|
|
|||
|
|
@ -17,14 +17,14 @@ type Valve struct {
|
|||
rxtb atomic.Value // *ratelimit.Bucket
|
||||
txtb atomic.Value // *ratelimit.Bucket
|
||||
|
||||
rxCredit *int64
|
||||
txCredit *int64
|
||||
rx *int64
|
||||
tx *int64
|
||||
}
|
||||
|
||||
func MakeValve(rxRate, txRate int64, rxCredit, txCredit *int64) *Valve {
|
||||
func MakeValve(rxRate, txRate int64, rx, tx *int64) *Valve {
|
||||
v := &Valve{
|
||||
rxCredit: rxCredit,
|
||||
txCredit: txCredit,
|
||||
rx: rx,
|
||||
tx: tx,
|
||||
}
|
||||
v.SetRxRate(rxRate)
|
||||
v.SetTxRate(txRate)
|
||||
|
|
@ -35,13 +35,12 @@ func (v *Valve) SetRxRate(rate int64) { v.rxtb.Store(ratelimit.NewBucketWithRate
|
|||
func (v *Valve) SetTxRate(rate int64) { v.txtb.Store(ratelimit.NewBucketWithRate(float64(rate), rate)) }
|
||||
func (v *Valve) rxWait(n int) { v.rxtb.Load().(*ratelimit.Bucket).Wait(int64(n)) }
|
||||
func (v *Valve) txWait(n int) { v.txtb.Load().(*ratelimit.Bucket).Wait(int64(n)) }
|
||||
func (v *Valve) SetRxCredit(n int64) { atomic.StoreInt64(v.rxCredit, n) }
|
||||
func (v *Valve) SetTxCredit(n int64) { atomic.StoreInt64(v.txCredit, n) }
|
||||
func (v *Valve) GetRxCredit() int64 { return atomic.LoadInt64(v.rxCredit) }
|
||||
func (v *Valve) GetTxCredit() int64 { return atomic.LoadInt64(v.txCredit) }
|
||||
|
||||
// n can be negative
|
||||
func (v *Valve) AddRxCredit(n int64) int64 { return atomic.AddInt64(v.rxCredit, n) }
|
||||
|
||||
// n can be negative
|
||||
func (v *Valve) AddTxCredit(n int64) int64 { return atomic.AddInt64(v.txCredit, n) }
|
||||
func (v *Valve) AddRx(n int64) { atomic.AddInt64(v.rx, n) }
|
||||
func (v *Valve) AddTx(n int64) { atomic.AddInt64(v.tx, n) }
|
||||
func (v *Valve) GetRx() int64 { return atomic.LoadInt64(v.rx) }
|
||||
func (v *Valve) GetTx() int64 { return atomic.LoadInt64(v.tx) }
|
||||
func (v *Valve) Nullify() (int64, int64) {
|
||||
rx := atomic.SwapInt64(v.rx, 0)
|
||||
tx := atomic.SwapInt64(v.tx, 0)
|
||||
return rx, tx
|
||||
}
|
||||
|
|
|
|||
|
|
@ -50,9 +50,6 @@ func makeSwitchboard(sesh *Session, valve *Valve) *switchboard {
|
|||
|
||||
var errNilOptimum error = errors.New("The optimal connection is nil")
|
||||
|
||||
var ErrNoRxCredit error = errors.New("No Rx credit is left")
|
||||
var ErrNoTxCredit error = errors.New("No Tx credit is left")
|
||||
|
||||
func (sb *switchboard) send(data []byte) (int, error) {
|
||||
ce := sb.getOptimum()
|
||||
if ce == nil {
|
||||
|
|
@ -65,11 +62,7 @@ func (sb *switchboard) send(data []byte) (int, error) {
|
|||
return n, err
|
||||
}
|
||||
sb.txWait(n)
|
||||
if sb.AddTxCredit(-int64(n)) < 0 {
|
||||
log.Println(ErrNoTxCredit)
|
||||
go sb.session.Close()
|
||||
return n, ErrNoTxCredit
|
||||
}
|
||||
sb.Valve.AddTx(int64(n))
|
||||
atomic.AddUint32(&ce.sendQueue, ^uint32(n-1))
|
||||
go sb.updateOptimum()
|
||||
return n, nil
|
||||
|
|
@ -133,17 +126,14 @@ func (sb *switchboard) deplex(ce *connEnclave) {
|
|||
for {
|
||||
n, err := sb.session.obfsedRead(ce.remoteConn, buf)
|
||||
sb.rxWait(n)
|
||||
sb.Valve.AddRx(int64(n))
|
||||
if err != nil {
|
||||
//log.Println(err)
|
||||
go ce.remoteConn.Close()
|
||||
sb.removeConn(ce)
|
||||
return
|
||||
}
|
||||
if sb.AddRxCredit(-int64(n)) < 0 {
|
||||
log.Println(ErrNoRxCredit)
|
||||
sb.session.Close()
|
||||
return
|
||||
}
|
||||
|
||||
frame, err := sb.session.deobfs(buf[:n])
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,67 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
mux "github.com/cbeuw/Cloak/internal/multiplex"
|
||||
)
|
||||
|
||||
type ActiveUser struct {
|
||||
panel *userPanel
|
||||
|
||||
arrUID [16]byte
|
||||
|
||||
valve *mux.Valve
|
||||
|
||||
sessionsM sync.RWMutex
|
||||
sessions map[uint32]*mux.Session
|
||||
}
|
||||
|
||||
func (u *ActiveUser) DelSession(sessionID uint32) {
|
||||
u.sessionsM.Lock()
|
||||
delete(u.sessions, sessionID)
|
||||
if len(u.sessions) == 0 {
|
||||
u.panel.updateUsageQueueForOne(u)
|
||||
u.panel.activeUsersM.Lock()
|
||||
delete(u.panel.activeUsers, u.arrUID)
|
||||
u.panel.activeUsersM.Unlock()
|
||||
}
|
||||
u.sessionsM.Unlock()
|
||||
}
|
||||
|
||||
func (u *ActiveUser) GetSession(sessionID uint32, obfs mux.Obfser, deobfs mux.Deobfser, obfsedRead func(net.Conn, []byte) (int, error)) (sesh *mux.Session, existing bool, err error) {
|
||||
u.sessionsM.Lock()
|
||||
if sesh = u.sessions[sessionID]; sesh != nil {
|
||||
u.sessionsM.Unlock()
|
||||
return sesh, true, nil
|
||||
} else {
|
||||
err := u.panel.manager.authoriseNewSession(u)
|
||||
if err != nil {
|
||||
u.sessionsM.Unlock()
|
||||
return nil, false, err
|
||||
}
|
||||
sesh = mux.MakeSession(sessionID, u.valve, obfs, deobfs, obfsedRead)
|
||||
u.sessions[sessionID] = sesh
|
||||
u.sessionsM.Unlock()
|
||||
return sesh, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (u *ActiveUser) Terminate() {
|
||||
u.sessionsM.Lock()
|
||||
for _, sesh := range u.sessions {
|
||||
go sesh.Close()
|
||||
}
|
||||
u.sessionsM.Unlock()
|
||||
u.panel.activeUsersM.Lock()
|
||||
delete(u.panel.activeUsers, u.arrUID)
|
||||
u.panel.activeUsersM.Unlock()
|
||||
}
|
||||
|
||||
func (u *ActiveUser) NumSession() int {
|
||||
u.sessionsM.RLock()
|
||||
l := len(u.sessions)
|
||||
u.sessionsM.RUnlock()
|
||||
return l
|
||||
}
|
||||
|
|
@ -8,8 +8,6 @@ import (
|
|||
"io/ioutil"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cbeuw/Cloak/internal/server/usermanager"
|
||||
)
|
||||
|
||||
type rawConfig struct {
|
||||
|
|
@ -19,6 +17,7 @@ type rawConfig struct {
|
|||
AdminUID string
|
||||
DatabasePath string
|
||||
BackupDirPath string
|
||||
CncMode bool
|
||||
}
|
||||
|
||||
// State type stores the global state of the program
|
||||
|
|
@ -28,14 +27,16 @@ type State struct {
|
|||
BindHost string
|
||||
BindPort string
|
||||
|
||||
Now func() time.Time
|
||||
AdminUID []byte
|
||||
staticPv crypto.PrivateKey
|
||||
Userpanel *usermanager.Userpanel
|
||||
Now func() time.Time
|
||||
AdminUID []byte
|
||||
staticPv crypto.PrivateKey
|
||||
|
||||
RedirAddr string
|
||||
|
||||
usedRandomM sync.RWMutex
|
||||
usedRandom map[[32]byte]int
|
||||
|
||||
RedirAddr string
|
||||
Panel *userPanel
|
||||
}
|
||||
|
||||
func InitState(bindHost, bindPort string, nowFunc func() time.Time) (*State, error) {
|
||||
|
|
@ -45,6 +46,7 @@ func InitState(bindHost, bindPort string, nowFunc func() time.Time) (*State, err
|
|||
Now: nowFunc,
|
||||
}
|
||||
ret.usedRandom = make(map[[32]byte]int)
|
||||
go ret.UsedRandomCleaner()
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
|
|
@ -66,11 +68,15 @@ func (sta *State) ParseConfig(conf string) (err error) {
|
|||
}
|
||||
}
|
||||
|
||||
up, err := usermanager.MakeUserpanel(preParse.DatabasePath, preParse.BackupDirPath)
|
||||
if err != nil {
|
||||
return errors.New("Attempting to open database: " + err.Error())
|
||||
if preParse.CncMode {
|
||||
|
||||
} else {
|
||||
manager, err := MakeLocalManager(preParse.DatabasePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sta.Panel = MakeUserPanel(manager)
|
||||
}
|
||||
sta.Userpanel = up
|
||||
|
||||
sta.RedirAddr = preParse.RedirAddr
|
||||
sta.ProxyBook = preParse.ProxyBook
|
||||
|
|
|
|||
|
|
@ -0,0 +1,115 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/boltdb/bolt"
|
||||
)
|
||||
|
||||
var Uint32 = binary.BigEndian.Uint32
|
||||
var Uint64 = binary.BigEndian.Uint64
|
||||
var PutUint32 = binary.BigEndian.PutUint32
|
||||
var PutUint64 = binary.BigEndian.PutUint64
|
||||
|
||||
type localManager struct {
|
||||
db *bolt.DB
|
||||
}
|
||||
|
||||
func MakeLocalManager(dbPath string) (*localManager, error) {
|
||||
db, err := bolt.Open(dbPath, 0600, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &localManager{db}, nil
|
||||
}
|
||||
|
||||
func (manager *localManager) authenticateUser(UID []byte) (int64, int64, error) {
|
||||
var upRate, downRate, upCredit, downCredit, expiryTime int64
|
||||
err := manager.db.View(func(tx *bolt.Tx) error {
|
||||
bucket := tx.Bucket(UID)
|
||||
if bucket == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
upRate = int64(Uint64(bucket.Get([]byte("UpRate"))))
|
||||
downRate = int64(Uint64(bucket.Get([]byte("DownRate"))))
|
||||
upCredit = int64(Uint64(bucket.Get([]byte("UpCredit"))))
|
||||
downCredit = int64(Uint64(bucket.Get([]byte("DownCredit"))))
|
||||
expiryTime = int64(Uint64(bucket.Get([]byte("ExpiryTime"))))
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
if upCredit <= 0 {
|
||||
return 0, 0, ErrNoUpCredit
|
||||
}
|
||||
if downCredit <= 0 {
|
||||
return 0, 0, ErrNoDownCredit
|
||||
}
|
||||
if expiryTime < time.Now().Unix() {
|
||||
return 0, 0, ErrUserExpired
|
||||
}
|
||||
|
||||
return upRate, downRate, nil
|
||||
}
|
||||
|
||||
func (manager *localManager) authoriseNewSession(user *ActiveUser) error {
|
||||
var sessionsCap int
|
||||
var upCredit, downCredit, expiryTime int64
|
||||
err := manager.db.View(func(tx *bolt.Tx) error {
|
||||
bucket := tx.Bucket(user.arrUID[:])
|
||||
if bucket == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
sessionsCap = int(Uint32(bucket.Get([]byte("SessionsCap"))))
|
||||
upCredit = int64(Uint64(bucket.Get([]byte("UpCredit"))))
|
||||
downCredit = int64(Uint64(bucket.Get([]byte("DownCredit"))))
|
||||
expiryTime = int64(Uint64(bucket.Get([]byte("ExpiryTime"))))
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if upCredit <= 0 {
|
||||
return ErrNoUpCredit
|
||||
}
|
||||
if downCredit <= 0 {
|
||||
return ErrNoDownCredit
|
||||
}
|
||||
if expiryTime < time.Now().Unix() {
|
||||
return ErrUserExpired
|
||||
}
|
||||
//user.sessionsM.RLock()
|
||||
if len(user.sessions) >= sessionsCap {
|
||||
//user.sessionsM.RUnlock()
|
||||
return ErrSessionsCapReached
|
||||
}
|
||||
//user.sessionsM.RUnlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func i64ToB(value int64) []byte {
|
||||
oct := make([]byte, 8)
|
||||
PutUint64(oct, uint64(value))
|
||||
return oct
|
||||
}
|
||||
|
||||
func (manager *localManager) uploadStatus(uploads []statusUpdate) error {
|
||||
err := manager.db.Update(func(tx *bolt.Tx) error {
|
||||
for _, status := range uploads {
|
||||
bucket := tx.Bucket(status.UID)
|
||||
if bucket == nil {
|
||||
log.Printf("%x doesn't exist\n", status.UID)
|
||||
continue
|
||||
}
|
||||
oldUp := int64(Uint64(bucket.Get([]byte("UpCredit"))))
|
||||
bucket.Put([]byte("UpCredit"), i64ToB(oldUp-status.upUsage))
|
||||
oldDown := int64(Uint64(bucket.Get([]byte("DownCredit"))))
|
||||
bucket.Put([]byte("DownCredit"), i64ToB(oldDown-status.downUsage))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
type statusUpdate struct {
|
||||
UID []byte
|
||||
active bool
|
||||
numSession int
|
||||
|
||||
upUsage int64
|
||||
downUsage int64
|
||||
timestamp int64
|
||||
}
|
||||
|
||||
var ErrUserNotFound = errors.New("UID does not correspond to a user")
|
||||
var ErrSessionsCapReached = errors.New("Sessions cap has reached")
|
||||
var ErrNoUpCredit = errors.New("No upload credit left")
|
||||
var ErrNoDownCredit = errors.New("No download credit left")
|
||||
var ErrUserExpired = errors.New("User has expired")
|
||||
|
||||
type UserManager interface {
|
||||
authenticateUser([]byte) (int64, int64, error)
|
||||
authoriseNewSession(*ActiveUser) error
|
||||
// TODO: fetch update's response
|
||||
uploadStatus([]statusUpdate) error
|
||||
}
|
||||
|
|
@ -1,212 +0,0 @@
|
|||
package usermanager
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
)
|
||||
|
||||
// FIXME: sanity checks. The server may panic due to user input
|
||||
|
||||
// TODO: manual backup
|
||||
|
||||
/*
|
||||
0 reserved
|
||||
1 listActiveUsers none []uids
|
||||
2 listAllUsers none []userinfo
|
||||
3 getUserInfo uid userinfo
|
||||
|
||||
4 addNewUser userinfo ok
|
||||
5 delUser uid ok
|
||||
6 syncMemFromDB uid ok
|
||||
|
||||
7 setSessionsCap uid cap ok
|
||||
8 setUpRate uid rate ok
|
||||
9 setDownRate uid rate ok
|
||||
10 setUpCredit uid credit ok
|
||||
11 setDownCredit uid credit ok
|
||||
12 setExpiryTime uid time ok
|
||||
13 addUpCredit uid delta ok
|
||||
14 addDownCredit uid delta ok
|
||||
*/
|
||||
|
||||
type controller struct {
|
||||
*Userpanel
|
||||
adminUID []byte
|
||||
}
|
||||
|
||||
func (up *Userpanel) MakeController(adminUID []byte) *controller {
|
||||
return &controller{up, adminUID}
|
||||
}
|
||||
|
||||
var errInvalidArgument = errors.New("Invalid argument format")
|
||||
|
||||
func (c *controller) HandleRequest(req []byte) (resp []byte, err error) {
|
||||
check := func(err error) []byte {
|
||||
if err != nil {
|
||||
return c.respond([]byte(err.Error()))
|
||||
} else {
|
||||
return c.respond([]byte("ok"))
|
||||
}
|
||||
}
|
||||
plain, err := c.checkAndDecrypt(req)
|
||||
if err == ErrInvalidMac {
|
||||
log.Printf("!!!CONTROL MESSAGE AND HMAC MISMATCH!!!\naUID:%x\nraw request:\n%x\ndecrypted msg:\n%x", c.adminUID, req, plain)
|
||||
return nil, err
|
||||
} else if err != nil {
|
||||
log.Printf("aUID:%x\n,err:%v\n", c.adminUID, err)
|
||||
return c.respond([]byte(err.Error())), nil
|
||||
}
|
||||
|
||||
typ := plain[0]
|
||||
var arg []byte
|
||||
if len(plain) > 1 {
|
||||
arg = plain[1:]
|
||||
}
|
||||
switch typ {
|
||||
case 1:
|
||||
UIDs := c.listActiveUsers()
|
||||
resp, _ = json.Marshal(UIDs)
|
||||
resp = c.respond(resp)
|
||||
case 2:
|
||||
uinfos := c.listAllUsers()
|
||||
resp, _ = json.Marshal(uinfos)
|
||||
resp = c.respond(resp)
|
||||
case 3:
|
||||
uinfo, err := c.getUserInfo(arg)
|
||||
if err != nil {
|
||||
resp = c.respond([]byte(err.Error()))
|
||||
break
|
||||
}
|
||||
resp, _ = json.Marshal(uinfo)
|
||||
resp = c.respond(resp)
|
||||
case 4:
|
||||
var uinfo UserInfo
|
||||
err = json.Unmarshal(arg, &uinfo)
|
||||
if err != nil {
|
||||
resp = c.respond([]byte(err.Error()))
|
||||
break
|
||||
}
|
||||
|
||||
err = c.addNewUser(uinfo)
|
||||
resp = check(err)
|
||||
case 5:
|
||||
err = c.delUser(arg)
|
||||
resp = check(err)
|
||||
case 6:
|
||||
err = c.syncMemFromDB(arg)
|
||||
resp = check(err)
|
||||
case 7:
|
||||
if len(arg) < 20 {
|
||||
resp = c.respond([]byte(errInvalidArgument.Error()))
|
||||
break
|
||||
}
|
||||
err = c.setSessionsCap(arg[0:16], Uint32(arg[16:20]))
|
||||
resp = check(err)
|
||||
case 8:
|
||||
if len(arg) < 24 {
|
||||
resp = c.respond([]byte(errInvalidArgument.Error()))
|
||||
break
|
||||
}
|
||||
err = c.setUpRate(arg[0:16], int64(Uint64(arg[16:24])))
|
||||
resp = check(err)
|
||||
case 9:
|
||||
if len(arg) < 24 {
|
||||
resp = c.respond([]byte(errInvalidArgument.Error()))
|
||||
break
|
||||
}
|
||||
err = c.setDownRate(arg[0:16], int64(Uint64(arg[16:24])))
|
||||
resp = check(err)
|
||||
case 10:
|
||||
if len(arg) < 24 {
|
||||
resp = c.respond([]byte(errInvalidArgument.Error()))
|
||||
break
|
||||
}
|
||||
err = c.setUpCredit(arg[0:16], int64(Uint64(arg[16:24])))
|
||||
resp = check(err)
|
||||
case 11:
|
||||
if len(arg) < 24 {
|
||||
resp = c.respond([]byte(errInvalidArgument.Error()))
|
||||
break
|
||||
}
|
||||
err = c.setDownCredit(arg[0:16], int64(Uint64(arg[16:24])))
|
||||
resp = check(err)
|
||||
case 12:
|
||||
if len(arg) < 24 {
|
||||
resp = c.respond([]byte(errInvalidArgument.Error()))
|
||||
break
|
||||
}
|
||||
err = c.setExpiryTime(arg[0:16], int64(Uint64(arg[16:24])))
|
||||
resp = check(err)
|
||||
case 13:
|
||||
if len(arg) < 24 {
|
||||
resp = c.respond([]byte(errInvalidArgument.Error()))
|
||||
break
|
||||
}
|
||||
err = c.addUpCredit(arg[0:16], int64(Uint64(arg[16:24])))
|
||||
resp = check(err)
|
||||
case 14:
|
||||
if len(arg) < 24 {
|
||||
resp = c.respond([]byte(errInvalidArgument.Error()))
|
||||
break
|
||||
}
|
||||
err = c.addDownCredit(arg[0:16], int64(Uint64(arg[16:24])))
|
||||
resp = check(err)
|
||||
default:
|
||||
return c.respond([]byte("Unsupported action")), nil
|
||||
|
||||
}
|
||||
return
|
||||
|
||||
}
|
||||
|
||||
var ErrInvalidMac = errors.New("Mac mismatch")
|
||||
var errMsgTooShort = errors.New("Message length is less than 54")
|
||||
|
||||
// protocol: [TLS record layer 5 bytes][IV 16 bytes][data][hmac 32 bytes]
|
||||
func (c *controller) respond(resp []byte) []byte {
|
||||
respLen := len(resp)
|
||||
|
||||
buf := make([]byte, 5+16+respLen+32)
|
||||
buf[0] = 0x17
|
||||
buf[1] = 0x03
|
||||
buf[2] = 0x03
|
||||
PutUint16(buf[3:5], uint16(16+respLen+32))
|
||||
|
||||
rand.Read(buf[5:21]) //iv
|
||||
copy(buf[21:], resp)
|
||||
block, _ := aes.NewCipher(c.adminUID)
|
||||
stream := cipher.NewCTR(block, buf[5:21])
|
||||
stream.XORKeyStream(buf[21:21+respLen], buf[21:21+respLen])
|
||||
|
||||
mac := hmac.New(sha256.New, c.adminUID)
|
||||
mac.Write(buf[5 : 21+respLen])
|
||||
copy(buf[21+respLen:], mac.Sum(nil))
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
func (c *controller) checkAndDecrypt(data []byte) ([]byte, error) {
|
||||
if len(data) < 54 {
|
||||
return nil, errMsgTooShort
|
||||
}
|
||||
macIndex := len(data) - 32
|
||||
mac := hmac.New(sha256.New, c.adminUID)
|
||||
mac.Write(data[5:macIndex])
|
||||
expected := mac.Sum(nil)
|
||||
if !hmac.Equal(data[macIndex:], expected) {
|
||||
return nil, ErrInvalidMac
|
||||
}
|
||||
|
||||
iv := data[5:21]
|
||||
ret := data[21:macIndex]
|
||||
block, _ := aes.NewCipher(c.adminUID)
|
||||
stream := cipher.NewCTR(block, iv)
|
||||
stream.XORKeyStream(ret, ret)
|
||||
return ret, nil
|
||||
}
|
||||
|
|
@ -1,98 +0,0 @@
|
|||
package usermanager
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
mux "github.com/cbeuw/Cloak/internal/multiplex"
|
||||
)
|
||||
|
||||
// for the ease of using json package
|
||||
type UserInfo struct {
|
||||
UID []byte
|
||||
// ALL of the following fields have to be accessed atomically
|
||||
SessionsCap uint32
|
||||
UpRate int64
|
||||
DownRate int64
|
||||
UpCredit int64
|
||||
DownCredit int64
|
||||
ExpiryTime int64
|
||||
}
|
||||
|
||||
type User struct {
|
||||
up *Userpanel
|
||||
|
||||
arrUID [16]byte
|
||||
|
||||
*UserInfo
|
||||
|
||||
valve *mux.Valve
|
||||
|
||||
sessionsM sync.RWMutex
|
||||
sessions map[uint32]*mux.Session
|
||||
}
|
||||
|
||||
func MakeUser(up *Userpanel, uinfo *UserInfo) *User {
|
||||
// this instance of valve is shared across ALL sessions of a user
|
||||
valve := mux.MakeValve(uinfo.UpRate, uinfo.DownRate, &uinfo.UpCredit, &uinfo.DownCredit)
|
||||
u := &User{
|
||||
up: up,
|
||||
UserInfo: uinfo,
|
||||
valve: valve,
|
||||
sessions: make(map[uint32]*mux.Session),
|
||||
}
|
||||
copy(u.arrUID[:], uinfo.UID)
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *User) addUpCredit(delta int64) { u.valve.AddRxCredit(delta) }
|
||||
func (u *User) addDownCredit(delta int64) { u.valve.AddTxCredit(delta) }
|
||||
func (u *User) setSessionsCap(cap uint32) { atomic.StoreUint32(&u.SessionsCap, cap) }
|
||||
func (u *User) setUpRate(rate int64) { u.valve.SetRxRate(rate) }
|
||||
func (u *User) setDownRate(rate int64) { u.valve.SetTxRate(rate) }
|
||||
func (u *User) setUpCredit(n int64) { u.valve.SetRxCredit(n) }
|
||||
func (u *User) setDownCredit(n int64) { u.valve.SetTxCredit(n) }
|
||||
func (u *User) setExpiryTime(time int64) { atomic.StoreInt64(&u.ExpiryTime, time) }
|
||||
|
||||
func (u *User) updateInfo(uinfo UserInfo) {
|
||||
u.setSessionsCap(uinfo.SessionsCap)
|
||||
u.setUpCredit(uinfo.UpCredit)
|
||||
u.setDownCredit(uinfo.DownCredit)
|
||||
u.setUpRate(uinfo.UpRate)
|
||||
u.setDownRate(uinfo.DownRate)
|
||||
u.setExpiryTime(uinfo.ExpiryTime)
|
||||
}
|
||||
|
||||
func (u *User) DelSession(sessionID uint32) {
|
||||
u.sessionsM.Lock()
|
||||
delete(u.sessions, sessionID)
|
||||
if len(u.sessions) == 0 {
|
||||
u.sessionsM.Unlock()
|
||||
u.up.delActiveUser(u.UID)
|
||||
return
|
||||
}
|
||||
u.sessionsM.Unlock()
|
||||
}
|
||||
|
||||
func (u *User) GetSession(sessionID uint32, obfs mux.Obfser, deobfs mux.Deobfser, obfsedRead func(net.Conn, []byte) (int, error)) (sesh *mux.Session, existing bool, err error) {
|
||||
if time.Now().Unix() > u.ExpiryTime {
|
||||
return nil, false, errors.New("Expiry time passed")
|
||||
}
|
||||
u.sessionsM.Lock()
|
||||
if sesh = u.sessions[sessionID]; sesh != nil {
|
||||
u.sessionsM.Unlock()
|
||||
return sesh, true, nil
|
||||
} else {
|
||||
if len(u.sessions) >= int(u.SessionsCap) {
|
||||
u.sessionsM.Unlock()
|
||||
return nil, false, errors.New("SessionsCap reached")
|
||||
}
|
||||
sesh = mux.MakeSession(sessionID, u.valve, obfs, deobfs, obfsedRead)
|
||||
u.sessions[sessionID] = sesh
|
||||
u.sessionsM.Unlock()
|
||||
return sesh, false, nil
|
||||
}
|
||||
}
|
||||
|
|
@ -1,473 +0,0 @@
|
|||
package usermanager
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"os"
|
||||
"path"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/boltdb/bolt"
|
||||
)
|
||||
|
||||
var Uint32 = binary.BigEndian.Uint32
|
||||
var Uint64 = binary.BigEndian.Uint64
|
||||
var PutUint16 = binary.BigEndian.PutUint16
|
||||
var PutUint32 = binary.BigEndian.PutUint32
|
||||
var PutUint64 = binary.BigEndian.PutUint64
|
||||
|
||||
type Userpanel struct {
|
||||
db *bolt.DB
|
||||
bakRoot string
|
||||
|
||||
activeUsersM sync.RWMutex
|
||||
activeUsers map[[16]byte]*User
|
||||
}
|
||||
|
||||
func MakeUserpanel(dbPath, bakRoot string) (*Userpanel, error) {
|
||||
db, err := bolt.Open(dbPath, 0600, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if bakRoot == "" {
|
||||
os.Mkdir("db-backup", 0777)
|
||||
bakRoot = "db-backup"
|
||||
}
|
||||
bakRoot = path.Clean(bakRoot)
|
||||
up := &Userpanel{
|
||||
db: db,
|
||||
bakRoot: bakRoot,
|
||||
activeUsers: make(map[[16]byte]*User),
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(time.Second * 10)
|
||||
up.updateCredits()
|
||||
}
|
||||
}()
|
||||
return up, nil
|
||||
}
|
||||
|
||||
// credits of all users are updated together so that there is only 1 goroutine managing it
|
||||
func (up *Userpanel) updateCredits() {
|
||||
up.activeUsersM.RLock()
|
||||
for _, u := range up.activeUsers {
|
||||
up.db.Update(func(tx *bolt.Tx) error {
|
||||
b := tx.Bucket(u.arrUID[:])
|
||||
if b == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
if err := b.Put([]byte("UpCredit"), i64ToB(u.valve.GetRxCredit())); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := b.Put([]byte("DownCredit"), i64ToB(u.valve.GetTxCredit())); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
||||
})
|
||||
}
|
||||
up.activeUsersM.RUnlock()
|
||||
|
||||
}
|
||||
|
||||
func (up *Userpanel) backupDB(bakFileName string) error {
|
||||
bakPath := up.bakRoot + "/" + bakFileName
|
||||
_, err := os.Stat(bakPath)
|
||||
if err == nil {
|
||||
return errors.New("Attempting to overwrite a file during backup!")
|
||||
}
|
||||
var bak *os.File
|
||||
if os.IsNotExist(err) {
|
||||
bak, err = os.Create(bakPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
err = up.db.View(func(tx *bolt.Tx) error {
|
||||
_, err := tx.WriteTo(bak)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
var ErrUserNotFound = errors.New("User does not exist in db")
|
||||
var ErrUserNotActive = errors.New("User is not active")
|
||||
|
||||
func (up *Userpanel) GetAndActivateAdminUser(AdminUID []byte) (*User, error) {
|
||||
up.activeUsersM.Lock()
|
||||
var arrUID [16]byte
|
||||
copy(arrUID[:], AdminUID)
|
||||
if user, ok := up.activeUsers[arrUID]; ok {
|
||||
up.activeUsersM.Unlock()
|
||||
return user, nil
|
||||
}
|
||||
|
||||
uinfo := UserInfo{
|
||||
UID: AdminUID,
|
||||
SessionsCap: 1e9,
|
||||
UpRate: 1e12,
|
||||
DownRate: 1e12,
|
||||
UpCredit: 1e15,
|
||||
DownCredit: 1e15,
|
||||
ExpiryTime: 1e15,
|
||||
}
|
||||
|
||||
user := MakeUser(up, &uinfo)
|
||||
up.activeUsers[arrUID] = user
|
||||
up.activeUsersM.Unlock()
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// GetUser is used to retrieve a user if s/he is active, or to retrieve the user's info
|
||||
// from the db and mark it as an active user
|
||||
func (up *Userpanel) GetAndActivateUser(UID []byte) (*User, error) {
|
||||
up.activeUsersM.Lock()
|
||||
var arrUID [16]byte
|
||||
copy(arrUID[:], UID)
|
||||
if user, ok := up.activeUsers[arrUID]; ok {
|
||||
up.activeUsersM.Unlock()
|
||||
return user, nil
|
||||
}
|
||||
|
||||
var uinfo UserInfo
|
||||
uinfo.UID = UID
|
||||
err := up.db.View(func(tx *bolt.Tx) error {
|
||||
b := tx.Bucket(UID[:])
|
||||
if b == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
uinfo.SessionsCap = Uint32(b.Get([]byte("SessionsCap")))
|
||||
uinfo.UpRate = int64(Uint64(b.Get([]byte("UpRate"))))
|
||||
uinfo.DownRate = int64(Uint64(b.Get([]byte("DownRate"))))
|
||||
uinfo.UpCredit = int64(Uint64(b.Get([]byte("UpCredit")))) // reee brackets
|
||||
uinfo.DownCredit = int64(Uint64(b.Get([]byte("DownCredit"))))
|
||||
uinfo.ExpiryTime = int64(Uint64(b.Get([]byte("ExpiryTime"))))
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
up.activeUsersM.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
u := MakeUser(up, &uinfo)
|
||||
up.activeUsers[arrUID] = u
|
||||
up.activeUsersM.Unlock()
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (up *Userpanel) updateDBEntryUint32(UID []byte, key string, value uint32) error {
|
||||
err := up.db.Update(func(tx *bolt.Tx) error {
|
||||
b := tx.Bucket(UID)
|
||||
if b == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
if err := b.Put([]byte(key), u32ToB(value)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (up *Userpanel) updateDBEntryInt64(UID []byte, key string, value int64) error {
|
||||
err := up.db.Update(func(tx *bolt.Tx) error {
|
||||
b := tx.Bucket(UID)
|
||||
if b == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
if err := b.Put([]byte(key), i64ToB(value)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// This is used when all sessions of a user close
|
||||
func (up *Userpanel) delActiveUser(UID []byte) {
|
||||
var arrUID [16]byte
|
||||
copy(arrUID[:], UID)
|
||||
up.activeUsersM.Lock()
|
||||
delete(up.activeUsers, arrUID)
|
||||
up.activeUsersM.Unlock()
|
||||
}
|
||||
|
||||
func (up *Userpanel) getActiveUser(UID []byte) *User {
|
||||
var arrUID [16]byte
|
||||
copy(arrUID[:], UID)
|
||||
up.activeUsersM.RLock()
|
||||
ret := up.activeUsers[arrUID]
|
||||
up.activeUsersM.RUnlock()
|
||||
return ret
|
||||
}
|
||||
|
||||
// below are remote control utilised functions
|
||||
|
||||
func (up *Userpanel) listActiveUsers() [][]byte {
|
||||
var ret [][]byte
|
||||
up.activeUsersM.RLock()
|
||||
for _, u := range up.activeUsers {
|
||||
ret = append(ret, u.UID)
|
||||
}
|
||||
up.activeUsersM.RUnlock()
|
||||
return ret
|
||||
}
|
||||
|
||||
func (up *Userpanel) listAllUsers() []UserInfo {
|
||||
var ret []UserInfo
|
||||
up.db.View(func(tx *bolt.Tx) error {
|
||||
tx.ForEach(func(UID []byte, b *bolt.Bucket) error {
|
||||
// if we want to avoid writing every single key out,
|
||||
// we would have to either make UserInfo a map,
|
||||
// or use reflect.
|
||||
// neither is convinient
|
||||
var uinfo UserInfo
|
||||
uinfo.UID = UID
|
||||
uinfo.SessionsCap = Uint32(b.Get([]byte("SessionsCap")))
|
||||
uinfo.UpRate = int64(Uint64(b.Get([]byte("UpRate"))))
|
||||
uinfo.DownRate = int64(Uint64(b.Get([]byte("DownRate"))))
|
||||
uinfo.UpCredit = int64(Uint64(b.Get([]byte("UpCredit"))))
|
||||
uinfo.DownCredit = int64(Uint64(b.Get([]byte("DownCredit"))))
|
||||
uinfo.ExpiryTime = int64(Uint64(b.Get([]byte("ExpiryTime"))))
|
||||
ret = append(ret, uinfo)
|
||||
return nil
|
||||
})
|
||||
return nil
|
||||
})
|
||||
return ret
|
||||
}
|
||||
|
||||
func (up *Userpanel) getUserInfo(UID []byte) (UserInfo, error) {
|
||||
var uinfo UserInfo
|
||||
err := up.db.View(func(tx *bolt.Tx) error {
|
||||
b := tx.Bucket(UID)
|
||||
if b == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
uinfo.UID = UID
|
||||
uinfo.SessionsCap = Uint32(b.Get([]byte("SessionsCap")))
|
||||
uinfo.UpRate = int64(Uint64(b.Get([]byte("UpRate"))))
|
||||
uinfo.DownRate = int64(Uint64(b.Get([]byte("DownRate"))))
|
||||
uinfo.UpCredit = int64(Uint64(b.Get([]byte("UpCredit"))))
|
||||
uinfo.DownCredit = int64(Uint64(b.Get([]byte("DownCredit"))))
|
||||
uinfo.ExpiryTime = int64(Uint64(b.Get([]byte("ExpiryTime"))))
|
||||
return nil
|
||||
})
|
||||
return uinfo, err
|
||||
}
|
||||
|
||||
// In boltdb, the value argument for bucket.Put has to be valid for the duration
|
||||
// of the transaction.
|
||||
// This basically means that you cannot reuse a byte slice for two different keys
|
||||
// in a transaction. So we need to allocate a fresh byte slice for each value
|
||||
func u32ToB(value uint32) []byte {
|
||||
quad := make([]byte, 4)
|
||||
PutUint32(quad, value)
|
||||
return quad
|
||||
}
|
||||
|
||||
func i64ToB(value int64) []byte {
|
||||
oct := make([]byte, 8)
|
||||
PutUint64(oct, uint64(value))
|
||||
return oct
|
||||
}
|
||||
|
||||
func (up *Userpanel) addNewUser(uinfo UserInfo) error {
|
||||
err := up.db.Update(func(tx *bolt.Tx) error {
|
||||
b, err := tx.CreateBucket(uinfo.UID[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = b.Put([]byte("SessionsCap"), u32ToB(uinfo.SessionsCap)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = b.Put([]byte("UpRate"), i64ToB(uinfo.UpRate)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = b.Put([]byte("DownRate"), i64ToB(uinfo.DownRate)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = b.Put([]byte("UpCredit"), i64ToB(uinfo.UpCredit)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = b.Put([]byte("DownCredit"), i64ToB(uinfo.DownCredit)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = b.Put([]byte("ExpiryTime"), i64ToB(uinfo.ExpiryTime)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (up *Userpanel) delUser(UID []byte) error {
|
||||
err := up.backupDB(strconv.FormatInt(time.Now().Unix(), 10) + "_pre_del_" + base64.StdEncoding.EncodeToString(UID) + ".bak")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = up.db.Update(func(tx *bolt.Tx) error {
|
||||
return tx.DeleteBucket(UID)
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (up *Userpanel) syncMemFromDB(UID []byte) error {
|
||||
var uinfo UserInfo
|
||||
err := up.db.View(func(tx *bolt.Tx) error {
|
||||
b := tx.Bucket(UID)
|
||||
if b == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
uinfo.UID = UID
|
||||
uinfo.SessionsCap = Uint32(b.Get([]byte("SessionsCap")))
|
||||
uinfo.UpRate = int64(Uint64(b.Get([]byte("UpRate"))))
|
||||
uinfo.DownRate = int64(Uint64(b.Get([]byte("DownRate"))))
|
||||
uinfo.UpCredit = int64(Uint64(b.Get([]byte("UpCredit"))))
|
||||
uinfo.DownCredit = int64(Uint64(b.Get([]byte("DownCredit"))))
|
||||
uinfo.ExpiryTime = int64(Uint64(b.Get([]byte("ExpiryTime"))))
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
u := up.getActiveUser(UID)
|
||||
if u == nil {
|
||||
return ErrUserNotActive
|
||||
}
|
||||
u.updateInfo(uinfo)
|
||||
return nil
|
||||
}
|
||||
|
||||
// the following functions will update the db entries first, then if the
|
||||
// user is active, it will update it in memory.
|
||||
|
||||
func (up *Userpanel) setSessionsCap(UID []byte, cap uint32) error {
|
||||
err := up.updateDBEntryUint32(UID, "SessionsCap", cap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u := up.getActiveUser(UID)
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
u.setSessionsCap(cap)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (up *Userpanel) setUpRate(UID []byte, rate int64) error {
|
||||
err := up.updateDBEntryInt64(UID, "UpRate", rate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u := up.getActiveUser(UID)
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
u.setUpRate(rate)
|
||||
return nil
|
||||
}
|
||||
func (up *Userpanel) setDownRate(UID []byte, rate int64) error {
|
||||
err := up.updateDBEntryInt64(UID, "DownRate", rate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u := up.getActiveUser(UID)
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
u.setDownRate(rate)
|
||||
return nil
|
||||
}
|
||||
func (up *Userpanel) setUpCredit(UID []byte, n int64) error {
|
||||
err := up.updateDBEntryInt64(UID, "UpCredit", n)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u := up.getActiveUser(UID)
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
u.setUpCredit(n)
|
||||
return nil
|
||||
}
|
||||
func (up *Userpanel) setDownCredit(UID []byte, n int64) error {
|
||||
err := up.updateDBEntryInt64(UID, "DownCredit", n)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u := up.getActiveUser(UID)
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
u.setDownCredit(n)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (up *Userpanel) setExpiryTime(UID []byte, time int64) error {
|
||||
err := up.updateDBEntryInt64(UID, "ExpiryTime", time)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u := up.getActiveUser(UID)
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
u.setExpiryTime(time)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (up *Userpanel) addUpCredit(UID []byte, delta int64) error {
|
||||
err := up.db.Update(func(tx *bolt.Tx) error {
|
||||
b := tx.Bucket(UID)
|
||||
if b == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
old := b.Get([]byte("UpCredit"))
|
||||
new := int64(Uint64(old)) + delta
|
||||
if err := b.Put([]byte("UpCredit"), i64ToB(new)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u := up.getActiveUser(UID)
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
u.addUpCredit(delta)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (up *Userpanel) addDownCredit(UID []byte, delta int64) error {
|
||||
err := up.db.Update(func(tx *bolt.Tx) error {
|
||||
b := tx.Bucket(UID)
|
||||
if b == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
old := b.Get([]byte("DownCredit"))
|
||||
new := int64(Uint64(old)) + delta
|
||||
if err := b.Put([]byte("DownCredit"), i64ToB(new)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u := up.getActiveUser(UID)
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
u.addDownCredit(delta)
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,138 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
mux "github.com/cbeuw/Cloak/internal/multiplex"
|
||||
)
|
||||
|
||||
type userPanel struct {
|
||||
manager UserManager
|
||||
|
||||
activeUsersM sync.RWMutex
|
||||
activeUsers map[[16]byte]*ActiveUser
|
||||
usageUpdateQueueM sync.Mutex
|
||||
usageUpdateQueue map[[16]byte]*usagePair
|
||||
}
|
||||
|
||||
func MakeUserPanel(manager UserManager) *userPanel {
|
||||
ret := &userPanel{
|
||||
manager: manager,
|
||||
activeUsers: make(map[[16]byte]*ActiveUser),
|
||||
usageUpdateQueue: make(map[[16]byte]*usagePair),
|
||||
}
|
||||
go ret.regularQueueUpload()
|
||||
return ret
|
||||
}
|
||||
|
||||
func (panel *userPanel) GetUser(UID []byte) (*ActiveUser, error) {
|
||||
panel.activeUsersM.Lock()
|
||||
var arrUID [16]byte
|
||||
copy(arrUID[:], UID)
|
||||
if user, ok := panel.activeUsers[arrUID]; ok {
|
||||
panel.activeUsersM.Unlock()
|
||||
return user, nil
|
||||
}
|
||||
|
||||
upRate, downRate, err := panel.manager.authenticateUser(UID)
|
||||
if err != nil {
|
||||
panel.activeUsersM.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
var upUsage, downUsage int64
|
||||
valve := mux.MakeValve(upRate, downRate, &upUsage, &downUsage)
|
||||
user := &ActiveUser{
|
||||
panel: panel,
|
||||
valve: valve,
|
||||
sessions: make(map[uint32]*mux.Session),
|
||||
}
|
||||
copy(user.arrUID[:], UID)
|
||||
panel.activeUsers[user.arrUID] = user
|
||||
panel.activeUsersM.Unlock()
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (panel *userPanel) isActive(UID []byte) bool {
|
||||
var arrUID [16]byte
|
||||
copy(arrUID[:], UID)
|
||||
panel.activeUsersM.RLock()
|
||||
_, ok := panel.activeUsers[arrUID]
|
||||
panel.activeUsersM.RUnlock()
|
||||
return ok
|
||||
}
|
||||
|
||||
type usagePair struct {
|
||||
up *int64
|
||||
down *int64
|
||||
}
|
||||
|
||||
func (panel *userPanel) updateUsageQueue() {
|
||||
panel.activeUsersM.Lock()
|
||||
panel.usageUpdateQueueM.Lock()
|
||||
for _, user := range panel.activeUsers {
|
||||
upIncured, downIncured := user.valve.Nullify()
|
||||
if usage, ok := panel.usageUpdateQueue[user.arrUID]; ok {
|
||||
atomic.AddInt64(usage.up, upIncured)
|
||||
atomic.AddInt64(usage.down, downIncured)
|
||||
} else {
|
||||
// if the user hasn't been added to the queue
|
||||
usage = &usagePair{&upIncured, &downIncured}
|
||||
panel.usageUpdateQueue[user.arrUID] = usage
|
||||
}
|
||||
}
|
||||
panel.activeUsersM.Unlock()
|
||||
panel.usageUpdateQueueM.Unlock()
|
||||
}
|
||||
|
||||
func (panel *userPanel) updateUsageQueueForOne(user *ActiveUser) {
|
||||
// used when one particular user deactivates
|
||||
upIncured, downIncured := user.valve.Nullify()
|
||||
panel.usageUpdateQueueM.Lock()
|
||||
if usage, ok := panel.usageUpdateQueue[user.arrUID]; ok {
|
||||
atomic.AddInt64(usage.up, upIncured)
|
||||
atomic.AddInt64(usage.down, downIncured)
|
||||
} else {
|
||||
usage = &usagePair{&upIncured, &downIncured}
|
||||
panel.usageUpdateQueue[user.arrUID] = usage
|
||||
}
|
||||
panel.usageUpdateQueueM.Unlock()
|
||||
|
||||
}
|
||||
|
||||
func (panel *userPanel) commitUpdate() {
|
||||
panel.usageUpdateQueueM.Lock()
|
||||
statuses := make([]statusUpdate, 0, len(panel.usageUpdateQueue))
|
||||
for arrUID, usage := range panel.usageUpdateQueue {
|
||||
panel.activeUsersM.RLock()
|
||||
user := panel.activeUsers[arrUID]
|
||||
panel.activeUsersM.RUnlock()
|
||||
var numSession int
|
||||
if user != nil {
|
||||
numSession = user.NumSession()
|
||||
}
|
||||
status := statusUpdate{
|
||||
UID: arrUID[:],
|
||||
active: panel.isActive(arrUID[:]),
|
||||
numSession: numSession,
|
||||
upUsage: *usage.up,
|
||||
downUsage: *usage.down,
|
||||
timestamp: time.Now().Unix(),
|
||||
}
|
||||
statuses = append(statuses, status)
|
||||
}
|
||||
panel.manager.uploadStatus(statuses)
|
||||
panel.usageUpdateQueue = make(map[[16]byte]*usagePair)
|
||||
panel.usageUpdateQueueM.Unlock()
|
||||
}
|
||||
|
||||
func (panel *userPanel) regularQueueUpload() {
|
||||
for {
|
||||
time.Sleep(1 * time.Minute)
|
||||
go func() {
|
||||
panel.updateUsageQueue()
|
||||
panel.commitUpdate()
|
||||
}()
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue