diff --git a/cmd/ck-server/ck-server.go b/cmd/ck-server/ck-server.go index 7499f18..8b0d507 100644 --- a/cmd/ck-server/ck-server.go +++ b/cmd/ck-server/ck-server.go @@ -1,7 +1,6 @@ package main import ( - "bytes" "encoding/base64" "flag" "fmt" @@ -14,7 +13,6 @@ import ( mux "github.com/cbeuw/Cloak/internal/multiplex" "github.com/cbeuw/Cloak/internal/server" - "github.com/cbeuw/Cloak/internal/server/usermanager" "github.com/cbeuw/Cloak/internal/util" ) @@ -103,41 +101,38 @@ func dispatchConnection(conn net.Conn, sta *server.State) { return nil } - // adminUID can use the server as normal with unlimited QoS credits. The adminUID is not - // added to the userinfo database. The distinction between going into the admin mode - // and normal proxy mode is that sessionID needs == 0 for admin mode - if bytes.Equal(UID, sta.AdminUID) && sessionID == 0 { - err = finishHandshake() - if err != nil { - log.Println(err) - return - } - c := sta.Userpanel.MakeController(sta.AdminUID) - for { - n, err := conn.Read(data) + /* + // adminUID can use the server as normal with unlimited QoS credits. The adminUID is not + // added to the userinfo database. The distinction between going into the admin mode + // and normal proxy mode is that sessionID needs == 0 for admin mode + if bytes.Equal(UID, sta.AdminUID) && sessionID == 0 { + err = finishHandshake() if err != nil { log.Println(err) return } - resp, err := c.HandleRequest(data[:n]) - if err != nil { - log.Println(err) - } - _, err = conn.Write(resp) - if err != nil { - log.Println(err) - return + c := sta.Userpanel.MakeController(sta.AdminUID) + for { + n, err := conn.Read(data) + if err != nil { + log.Println(err) + return + } + resp, err := c.HandleRequest(data[:n]) + if err != nil { + log.Println(err) + } + _, err = conn.Write(resp) + if err != nil { + log.Println(err) + return + } } + } + */ - } - - var user *usermanager.User - if bytes.Equal(UID, sta.AdminUID) { - user, err = sta.Userpanel.GetAndActivateAdminUser(UID) - } else { - user, err = sta.Userpanel.GetAndActivateUser(UID) - } + user, err := sta.Panel.GetUser(UID) if err != nil { log.Printf("+1 unauthorised user from %v, uid: %x\n", conn.RemoteAddr(), UID) goWeb(data) @@ -278,10 +273,6 @@ func main() { sta.ProxyBook["shadowsocks"] = ssLocalHost + ":" + ssLocalPort } - if sta.AdminUID == nil { - log.Fatalln("AdminUID cannot be empty!") - } - go sta.UsedRandomCleaner() listen := func(addr, port string) { diff --git a/internal/multiplex/qos.go b/internal/multiplex/qos.go index 9f70957..a22d584 100644 --- a/internal/multiplex/qos.go +++ b/internal/multiplex/qos.go @@ -17,14 +17,14 @@ type Valve struct { rxtb atomic.Value // *ratelimit.Bucket txtb atomic.Value // *ratelimit.Bucket - rxCredit *int64 - txCredit *int64 + rx *int64 + tx *int64 } -func MakeValve(rxRate, txRate int64, rxCredit, txCredit *int64) *Valve { +func MakeValve(rxRate, txRate int64, rx, tx *int64) *Valve { v := &Valve{ - rxCredit: rxCredit, - txCredit: txCredit, + rx: rx, + tx: tx, } v.SetRxRate(rxRate) v.SetTxRate(txRate) @@ -35,13 +35,12 @@ func (v *Valve) SetRxRate(rate int64) { v.rxtb.Store(ratelimit.NewBucketWithRate func (v *Valve) SetTxRate(rate int64) { v.txtb.Store(ratelimit.NewBucketWithRate(float64(rate), rate)) } func (v *Valve) rxWait(n int) { v.rxtb.Load().(*ratelimit.Bucket).Wait(int64(n)) } func (v *Valve) txWait(n int) { v.txtb.Load().(*ratelimit.Bucket).Wait(int64(n)) } -func (v *Valve) SetRxCredit(n int64) { atomic.StoreInt64(v.rxCredit, n) } -func (v *Valve) SetTxCredit(n int64) { atomic.StoreInt64(v.txCredit, n) } -func (v *Valve) GetRxCredit() int64 { return atomic.LoadInt64(v.rxCredit) } -func (v *Valve) GetTxCredit() int64 { return atomic.LoadInt64(v.txCredit) } - -// n can be negative -func (v *Valve) AddRxCredit(n int64) int64 { return atomic.AddInt64(v.rxCredit, n) } - -// n can be negative -func (v *Valve) AddTxCredit(n int64) int64 { return atomic.AddInt64(v.txCredit, n) } +func (v *Valve) AddRx(n int64) { atomic.AddInt64(v.rx, n) } +func (v *Valve) AddTx(n int64) { atomic.AddInt64(v.tx, n) } +func (v *Valve) GetRx() int64 { return atomic.LoadInt64(v.rx) } +func (v *Valve) GetTx() int64 { return atomic.LoadInt64(v.tx) } +func (v *Valve) Nullify() (int64, int64) { + rx := atomic.SwapInt64(v.rx, 0) + tx := atomic.SwapInt64(v.tx, 0) + return rx, tx +} diff --git a/internal/multiplex/switchboard.go b/internal/multiplex/switchboard.go index 0bb13a1..feb73fd 100644 --- a/internal/multiplex/switchboard.go +++ b/internal/multiplex/switchboard.go @@ -50,9 +50,6 @@ func makeSwitchboard(sesh *Session, valve *Valve) *switchboard { var errNilOptimum error = errors.New("The optimal connection is nil") -var ErrNoRxCredit error = errors.New("No Rx credit is left") -var ErrNoTxCredit error = errors.New("No Tx credit is left") - func (sb *switchboard) send(data []byte) (int, error) { ce := sb.getOptimum() if ce == nil { @@ -65,11 +62,7 @@ func (sb *switchboard) send(data []byte) (int, error) { return n, err } sb.txWait(n) - if sb.AddTxCredit(-int64(n)) < 0 { - log.Println(ErrNoTxCredit) - go sb.session.Close() - return n, ErrNoTxCredit - } + sb.Valve.AddTx(int64(n)) atomic.AddUint32(&ce.sendQueue, ^uint32(n-1)) go sb.updateOptimum() return n, nil @@ -133,17 +126,14 @@ func (sb *switchboard) deplex(ce *connEnclave) { for { n, err := sb.session.obfsedRead(ce.remoteConn, buf) sb.rxWait(n) + sb.Valve.AddRx(int64(n)) if err != nil { //log.Println(err) go ce.remoteConn.Close() sb.removeConn(ce) return } - if sb.AddRxCredit(-int64(n)) < 0 { - log.Println(ErrNoRxCredit) - sb.session.Close() - return - } + frame, err := sb.session.deobfs(buf[:n]) if err != nil { log.Println(err) diff --git a/internal/server/activeuser.go b/internal/server/activeuser.go new file mode 100644 index 0000000..def2c3c --- /dev/null +++ b/internal/server/activeuser.go @@ -0,0 +1,67 @@ +package server + +import ( + "net" + "sync" + + mux "github.com/cbeuw/Cloak/internal/multiplex" +) + +type ActiveUser struct { + panel *userPanel + + arrUID [16]byte + + valve *mux.Valve + + sessionsM sync.RWMutex + sessions map[uint32]*mux.Session +} + +func (u *ActiveUser) DelSession(sessionID uint32) { + u.sessionsM.Lock() + delete(u.sessions, sessionID) + if len(u.sessions) == 0 { + u.panel.updateUsageQueueForOne(u) + u.panel.activeUsersM.Lock() + delete(u.panel.activeUsers, u.arrUID) + u.panel.activeUsersM.Unlock() + } + u.sessionsM.Unlock() +} + +func (u *ActiveUser) GetSession(sessionID uint32, obfs mux.Obfser, deobfs mux.Deobfser, obfsedRead func(net.Conn, []byte) (int, error)) (sesh *mux.Session, existing bool, err error) { + u.sessionsM.Lock() + if sesh = u.sessions[sessionID]; sesh != nil { + u.sessionsM.Unlock() + return sesh, true, nil + } else { + err := u.panel.manager.authoriseNewSession(u) + if err != nil { + u.sessionsM.Unlock() + return nil, false, err + } + sesh = mux.MakeSession(sessionID, u.valve, obfs, deobfs, obfsedRead) + u.sessions[sessionID] = sesh + u.sessionsM.Unlock() + return sesh, false, nil + } +} + +func (u *ActiveUser) Terminate() { + u.sessionsM.Lock() + for _, sesh := range u.sessions { + go sesh.Close() + } + u.sessionsM.Unlock() + u.panel.activeUsersM.Lock() + delete(u.panel.activeUsers, u.arrUID) + u.panel.activeUsersM.Unlock() +} + +func (u *ActiveUser) NumSession() int { + u.sessionsM.RLock() + l := len(u.sessions) + u.sessionsM.RUnlock() + return l +} diff --git a/internal/server/state.go b/internal/server/state.go index d4f4178..5ca0920 100644 --- a/internal/server/state.go +++ b/internal/server/state.go @@ -8,8 +8,6 @@ import ( "io/ioutil" "sync" "time" - - "github.com/cbeuw/Cloak/internal/server/usermanager" ) type rawConfig struct { @@ -19,6 +17,7 @@ type rawConfig struct { AdminUID string DatabasePath string BackupDirPath string + CncMode bool } // State type stores the global state of the program @@ -28,14 +27,16 @@ type State struct { BindHost string BindPort string - Now func() time.Time - AdminUID []byte - staticPv crypto.PrivateKey - Userpanel *usermanager.Userpanel + Now func() time.Time + AdminUID []byte + staticPv crypto.PrivateKey + + RedirAddr string + usedRandomM sync.RWMutex usedRandom map[[32]byte]int - RedirAddr string + Panel *userPanel } func InitState(bindHost, bindPort string, nowFunc func() time.Time) (*State, error) { @@ -45,6 +46,7 @@ func InitState(bindHost, bindPort string, nowFunc func() time.Time) (*State, err Now: nowFunc, } ret.usedRandom = make(map[[32]byte]int) + go ret.UsedRandomCleaner() return ret, nil } @@ -66,11 +68,15 @@ func (sta *State) ParseConfig(conf string) (err error) { } } - up, err := usermanager.MakeUserpanel(preParse.DatabasePath, preParse.BackupDirPath) - if err != nil { - return errors.New("Attempting to open database: " + err.Error()) + if preParse.CncMode { + + } else { + manager, err := MakeLocalManager(preParse.DatabasePath) + if err != nil { + return err + } + sta.Panel = MakeUserPanel(manager) } - sta.Userpanel = up sta.RedirAddr = preParse.RedirAddr sta.ProxyBook = preParse.ProxyBook diff --git a/internal/server/um_local.go b/internal/server/um_local.go new file mode 100644 index 0000000..59785f3 --- /dev/null +++ b/internal/server/um_local.go @@ -0,0 +1,115 @@ +package server + +import ( + "encoding/binary" + "log" + "time" + + "github.com/boltdb/bolt" +) + +var Uint32 = binary.BigEndian.Uint32 +var Uint64 = binary.BigEndian.Uint64 +var PutUint32 = binary.BigEndian.PutUint32 +var PutUint64 = binary.BigEndian.PutUint64 + +type localManager struct { + db *bolt.DB +} + +func MakeLocalManager(dbPath string) (*localManager, error) { + db, err := bolt.Open(dbPath, 0600, nil) + if err != nil { + return nil, err + } + return &localManager{db}, nil +} + +func (manager *localManager) authenticateUser(UID []byte) (int64, int64, error) { + var upRate, downRate, upCredit, downCredit, expiryTime int64 + err := manager.db.View(func(tx *bolt.Tx) error { + bucket := tx.Bucket(UID) + if bucket == nil { + return ErrUserNotFound + } + upRate = int64(Uint64(bucket.Get([]byte("UpRate")))) + downRate = int64(Uint64(bucket.Get([]byte("DownRate")))) + upCredit = int64(Uint64(bucket.Get([]byte("UpCredit")))) + downCredit = int64(Uint64(bucket.Get([]byte("DownCredit")))) + expiryTime = int64(Uint64(bucket.Get([]byte("ExpiryTime")))) + return nil + }) + if err != nil { + return 0, 0, err + } + if upCredit <= 0 { + return 0, 0, ErrNoUpCredit + } + if downCredit <= 0 { + return 0, 0, ErrNoDownCredit + } + if expiryTime < time.Now().Unix() { + return 0, 0, ErrUserExpired + } + + return upRate, downRate, nil +} + +func (manager *localManager) authoriseNewSession(user *ActiveUser) error { + var sessionsCap int + var upCredit, downCredit, expiryTime int64 + err := manager.db.View(func(tx *bolt.Tx) error { + bucket := tx.Bucket(user.arrUID[:]) + if bucket == nil { + return ErrUserNotFound + } + sessionsCap = int(Uint32(bucket.Get([]byte("SessionsCap")))) + upCredit = int64(Uint64(bucket.Get([]byte("UpCredit")))) + downCredit = int64(Uint64(bucket.Get([]byte("DownCredit")))) + expiryTime = int64(Uint64(bucket.Get([]byte("ExpiryTime")))) + return nil + }) + if err != nil { + return err + } + if upCredit <= 0 { + return ErrNoUpCredit + } + if downCredit <= 0 { + return ErrNoDownCredit + } + if expiryTime < time.Now().Unix() { + return ErrUserExpired + } + //user.sessionsM.RLock() + if len(user.sessions) >= sessionsCap { + //user.sessionsM.RUnlock() + return ErrSessionsCapReached + } + //user.sessionsM.RUnlock() + return nil +} + +func i64ToB(value int64) []byte { + oct := make([]byte, 8) + PutUint64(oct, uint64(value)) + return oct +} + +func (manager *localManager) uploadStatus(uploads []statusUpdate) error { + err := manager.db.Update(func(tx *bolt.Tx) error { + for _, status := range uploads { + bucket := tx.Bucket(status.UID) + if bucket == nil { + log.Printf("%x doesn't exist\n", status.UID) + continue + } + oldUp := int64(Uint64(bucket.Get([]byte("UpCredit")))) + bucket.Put([]byte("UpCredit"), i64ToB(oldUp-status.upUsage)) + oldDown := int64(Uint64(bucket.Get([]byte("DownCredit")))) + bucket.Put([]byte("DownCredit"), i64ToB(oldDown-status.downUsage)) + } + return nil + }) + return err +} diff --git a/internal/server/usermanager.go b/internal/server/usermanager.go new file mode 100644 index 0000000..f62a0b1 --- /dev/null +++ b/internal/server/usermanager.go @@ -0,0 +1,28 @@ +package server + +import ( + "errors" +) + +type statusUpdate struct { + UID []byte + active bool + numSession int + + upUsage int64 + downUsage int64 + timestamp int64 +} + +var ErrUserNotFound = errors.New("UID does not correspond to a user") +var ErrSessionsCapReached = errors.New("Sessions cap has reached") +var ErrNoUpCredit = errors.New("No upload credit left") +var ErrNoDownCredit = errors.New("No download credit left") +var ErrUserExpired = errors.New("User has expired") + +type UserManager interface { + authenticateUser([]byte) (int64, int64, error) + authoriseNewSession(*ActiveUser) error + // TODO: fetch update's response + uploadStatus([]statusUpdate) error +} diff --git a/internal/server/usermanager/controller.go b/internal/server/usermanager/controller.go deleted file mode 100644 index f62e82e..0000000 --- a/internal/server/usermanager/controller.go +++ /dev/null @@ -1,212 +0,0 @@ -package usermanager - -import ( - "crypto/aes" - "crypto/cipher" - "crypto/hmac" - "crypto/rand" - "crypto/sha256" - "encoding/json" - "errors" - "log" -) - -// FIXME: sanity checks. The server may panic due to user input - -// TODO: manual backup - -/* -0 reserved -1 listActiveUsers none []uids -2 listAllUsers none []userinfo -3 getUserInfo uid userinfo - -4 addNewUser userinfo ok -5 delUser uid ok -6 syncMemFromDB uid ok - -7 setSessionsCap uid cap ok -8 setUpRate uid rate ok -9 setDownRate uid rate ok -10 setUpCredit uid credit ok -11 setDownCredit uid credit ok -12 setExpiryTime uid time ok -13 addUpCredit uid delta ok -14 addDownCredit uid delta ok -*/ - -type controller struct { - *Userpanel - adminUID []byte -} - -func (up *Userpanel) MakeController(adminUID []byte) *controller { - return &controller{up, adminUID} -} - -var errInvalidArgument = errors.New("Invalid argument format") - -func (c *controller) HandleRequest(req []byte) (resp []byte, err error) { - check := func(err error) []byte { - if err != nil { - return c.respond([]byte(err.Error())) - } else { - return c.respond([]byte("ok")) - } - } - plain, err := c.checkAndDecrypt(req) - if err == ErrInvalidMac { - log.Printf("!!!CONTROL MESSAGE AND HMAC MISMATCH!!!\naUID:%x\nraw request:\n%x\ndecrypted msg:\n%x", c.adminUID, req, plain) - return nil, err - } else if err != nil { - log.Printf("aUID:%x\n,err:%v\n", c.adminUID, err) - return c.respond([]byte(err.Error())), nil - } - - typ := plain[0] - var arg []byte - if len(plain) > 1 { - arg = plain[1:] - } - switch typ { - case 1: - UIDs := c.listActiveUsers() - resp, _ = json.Marshal(UIDs) - resp = c.respond(resp) - case 2: - uinfos := c.listAllUsers() - resp, _ = json.Marshal(uinfos) - resp = c.respond(resp) - case 3: - uinfo, err := c.getUserInfo(arg) - if err != nil { - resp = c.respond([]byte(err.Error())) - break - } - resp, _ = json.Marshal(uinfo) - resp = c.respond(resp) - case 4: - var uinfo UserInfo - err = json.Unmarshal(arg, &uinfo) - if err != nil { - resp = c.respond([]byte(err.Error())) - break - } - - err = c.addNewUser(uinfo) - resp = check(err) - case 5: - err = c.delUser(arg) - resp = check(err) - case 6: - err = c.syncMemFromDB(arg) - resp = check(err) - case 7: - if len(arg) < 20 { - resp = c.respond([]byte(errInvalidArgument.Error())) - break - } - err = c.setSessionsCap(arg[0:16], Uint32(arg[16:20])) - resp = check(err) - case 8: - if len(arg) < 24 { - resp = c.respond([]byte(errInvalidArgument.Error())) - break - } - err = c.setUpRate(arg[0:16], int64(Uint64(arg[16:24]))) - resp = check(err) - case 9: - if len(arg) < 24 { - resp = c.respond([]byte(errInvalidArgument.Error())) - break - } - err = c.setDownRate(arg[0:16], int64(Uint64(arg[16:24]))) - resp = check(err) - case 10: - if len(arg) < 24 { - resp = c.respond([]byte(errInvalidArgument.Error())) - break - } - err = c.setUpCredit(arg[0:16], int64(Uint64(arg[16:24]))) - resp = check(err) - case 11: - if len(arg) < 24 { - resp = c.respond([]byte(errInvalidArgument.Error())) - break - } - err = c.setDownCredit(arg[0:16], int64(Uint64(arg[16:24]))) - resp = check(err) - case 12: - if len(arg) < 24 { - resp = c.respond([]byte(errInvalidArgument.Error())) - break - } - err = c.setExpiryTime(arg[0:16], int64(Uint64(arg[16:24]))) - resp = check(err) - case 13: - if len(arg) < 24 { - resp = c.respond([]byte(errInvalidArgument.Error())) - break - } - err = c.addUpCredit(arg[0:16], int64(Uint64(arg[16:24]))) - resp = check(err) - case 14: - if len(arg) < 24 { - resp = c.respond([]byte(errInvalidArgument.Error())) - break - } - err = c.addDownCredit(arg[0:16], int64(Uint64(arg[16:24]))) - resp = check(err) - default: - return c.respond([]byte("Unsupported action")), nil - - } - return - -} - -var ErrInvalidMac = errors.New("Mac mismatch") -var errMsgTooShort = errors.New("Message length is less than 54") - -// protocol: [TLS record layer 5 bytes][IV 16 bytes][data][hmac 32 bytes] -func (c *controller) respond(resp []byte) []byte { - respLen := len(resp) - - buf := make([]byte, 5+16+respLen+32) - buf[0] = 0x17 - buf[1] = 0x03 - buf[2] = 0x03 - PutUint16(buf[3:5], uint16(16+respLen+32)) - - rand.Read(buf[5:21]) //iv - copy(buf[21:], resp) - block, _ := aes.NewCipher(c.adminUID) - stream := cipher.NewCTR(block, buf[5:21]) - stream.XORKeyStream(buf[21:21+respLen], buf[21:21+respLen]) - - mac := hmac.New(sha256.New, c.adminUID) - mac.Write(buf[5 : 21+respLen]) - copy(buf[21+respLen:], mac.Sum(nil)) - - return buf -} - -func (c *controller) checkAndDecrypt(data []byte) ([]byte, error) { - if len(data) < 54 { - return nil, errMsgTooShort - } - macIndex := len(data) - 32 - mac := hmac.New(sha256.New, c.adminUID) - mac.Write(data[5:macIndex]) - expected := mac.Sum(nil) - if !hmac.Equal(data[macIndex:], expected) { - return nil, ErrInvalidMac - } - - iv := data[5:21] - ret := data[21:macIndex] - block, _ := aes.NewCipher(c.adminUID) - stream := cipher.NewCTR(block, iv) - stream.XORKeyStream(ret, ret) - return ret, nil -} diff --git a/internal/server/usermanager/user.go b/internal/server/usermanager/user.go deleted file mode 100644 index e4bc020..0000000 --- a/internal/server/usermanager/user.go +++ /dev/null @@ -1,98 +0,0 @@ -package usermanager - -import ( - "errors" - "net" - "sync" - "sync/atomic" - "time" - - mux "github.com/cbeuw/Cloak/internal/multiplex" -) - -// for the ease of using json package -type UserInfo struct { - UID []byte - // ALL of the following fields have to be accessed atomically - SessionsCap uint32 - UpRate int64 - DownRate int64 - UpCredit int64 - DownCredit int64 - ExpiryTime int64 -} - -type User struct { - up *Userpanel - - arrUID [16]byte - - *UserInfo - - valve *mux.Valve - - sessionsM sync.RWMutex - sessions map[uint32]*mux.Session -} - -func MakeUser(up *Userpanel, uinfo *UserInfo) *User { - // this instance of valve is shared across ALL sessions of a user - valve := mux.MakeValve(uinfo.UpRate, uinfo.DownRate, &uinfo.UpCredit, &uinfo.DownCredit) - u := &User{ - up: up, - UserInfo: uinfo, - valve: valve, - sessions: make(map[uint32]*mux.Session), - } - copy(u.arrUID[:], uinfo.UID) - return u -} - -func (u *User) addUpCredit(delta int64) { u.valve.AddRxCredit(delta) } -func (u *User) addDownCredit(delta int64) { u.valve.AddTxCredit(delta) } -func (u *User) setSessionsCap(cap uint32) { atomic.StoreUint32(&u.SessionsCap, cap) } -func (u *User) setUpRate(rate int64) { u.valve.SetRxRate(rate) } -func (u *User) setDownRate(rate int64) { u.valve.SetTxRate(rate) } -func (u *User) setUpCredit(n int64) { u.valve.SetRxCredit(n) } -func (u *User) setDownCredit(n int64) { u.valve.SetTxCredit(n) } -func (u *User) setExpiryTime(time int64) { atomic.StoreInt64(&u.ExpiryTime, time) } - -func (u *User) updateInfo(uinfo UserInfo) { - u.setSessionsCap(uinfo.SessionsCap) - u.setUpCredit(uinfo.UpCredit) - u.setDownCredit(uinfo.DownCredit) - u.setUpRate(uinfo.UpRate) - u.setDownRate(uinfo.DownRate) - u.setExpiryTime(uinfo.ExpiryTime) -} - -func (u *User) DelSession(sessionID uint32) { - u.sessionsM.Lock() - delete(u.sessions, sessionID) - if len(u.sessions) == 0 { - u.sessionsM.Unlock() - u.up.delActiveUser(u.UID) - return - } - u.sessionsM.Unlock() -} - -func (u *User) GetSession(sessionID uint32, obfs mux.Obfser, deobfs mux.Deobfser, obfsedRead func(net.Conn, []byte) (int, error)) (sesh *mux.Session, existing bool, err error) { - if time.Now().Unix() > u.ExpiryTime { - return nil, false, errors.New("Expiry time passed") - } - u.sessionsM.Lock() - if sesh = u.sessions[sessionID]; sesh != nil { - u.sessionsM.Unlock() - return sesh, true, nil - } else { - if len(u.sessions) >= int(u.SessionsCap) { - u.sessionsM.Unlock() - return nil, false, errors.New("SessionsCap reached") - } - sesh = mux.MakeSession(sessionID, u.valve, obfs, deobfs, obfsedRead) - u.sessions[sessionID] = sesh - u.sessionsM.Unlock() - return sesh, false, nil - } -} diff --git a/internal/server/usermanager/userpanel.go b/internal/server/usermanager/userpanel.go deleted file mode 100644 index 64e2005..0000000 --- a/internal/server/usermanager/userpanel.go +++ /dev/null @@ -1,473 +0,0 @@ -package usermanager - -import ( - "encoding/base64" - "encoding/binary" - "errors" - "os" - "path" - "strconv" - "sync" - "time" - - "github.com/boltdb/bolt" -) - -var Uint32 = binary.BigEndian.Uint32 -var Uint64 = binary.BigEndian.Uint64 -var PutUint16 = binary.BigEndian.PutUint16 -var PutUint32 = binary.BigEndian.PutUint32 -var PutUint64 = binary.BigEndian.PutUint64 - -type Userpanel struct { - db *bolt.DB - bakRoot string - - activeUsersM sync.RWMutex - activeUsers map[[16]byte]*User -} - -func MakeUserpanel(dbPath, bakRoot string) (*Userpanel, error) { - db, err := bolt.Open(dbPath, 0600, nil) - if err != nil { - return nil, err - } - if bakRoot == "" { - os.Mkdir("db-backup", 0777) - bakRoot = "db-backup" - } - bakRoot = path.Clean(bakRoot) - up := &Userpanel{ - db: db, - bakRoot: bakRoot, - activeUsers: make(map[[16]byte]*User), - } - go func() { - for { - time.Sleep(time.Second * 10) - up.updateCredits() - } - }() - return up, nil -} - -// credits of all users are updated together so that there is only 1 goroutine managing it -func (up *Userpanel) updateCredits() { - up.activeUsersM.RLock() - for _, u := range up.activeUsers { - up.db.Update(func(tx *bolt.Tx) error { - b := tx.Bucket(u.arrUID[:]) - if b == nil { - return ErrUserNotFound - } - if err := b.Put([]byte("UpCredit"), i64ToB(u.valve.GetRxCredit())); err != nil { - return err - } - if err := b.Put([]byte("DownCredit"), i64ToB(u.valve.GetTxCredit())); err != nil { - return err - } - return nil - - }) - } - up.activeUsersM.RUnlock() - -} - -func (up *Userpanel) backupDB(bakFileName string) error { - bakPath := up.bakRoot + "/" + bakFileName - _, err := os.Stat(bakPath) - if err == nil { - return errors.New("Attempting to overwrite a file during backup!") - } - var bak *os.File - if os.IsNotExist(err) { - bak, err = os.Create(bakPath) - if err != nil { - return err - } - } - err = up.db.View(func(tx *bolt.Tx) error { - _, err := tx.WriteTo(bak) - if err != nil { - return err - } - return nil - }) - return err -} - -var ErrUserNotFound = errors.New("User does not exist in db") -var ErrUserNotActive = errors.New("User is not active") - -func (up *Userpanel) GetAndActivateAdminUser(AdminUID []byte) (*User, error) { - up.activeUsersM.Lock() - var arrUID [16]byte - copy(arrUID[:], AdminUID) - if user, ok := up.activeUsers[arrUID]; ok { - up.activeUsersM.Unlock() - return user, nil - } - - uinfo := UserInfo{ - UID: AdminUID, - SessionsCap: 1e9, - UpRate: 1e12, - DownRate: 1e12, - UpCredit: 1e15, - DownCredit: 1e15, - ExpiryTime: 1e15, - } - - user := MakeUser(up, &uinfo) - up.activeUsers[arrUID] = user - up.activeUsersM.Unlock() - return user, nil -} - -// GetUser is used to retrieve a user if s/he is active, or to retrieve the user's info -// from the db and mark it as an active user -func (up *Userpanel) GetAndActivateUser(UID []byte) (*User, error) { - up.activeUsersM.Lock() - var arrUID [16]byte - copy(arrUID[:], UID) - if user, ok := up.activeUsers[arrUID]; ok { - up.activeUsersM.Unlock() - return user, nil - } - - var uinfo UserInfo - uinfo.UID = UID - err := up.db.View(func(tx *bolt.Tx) error { - b := tx.Bucket(UID[:]) - if b == nil { - return ErrUserNotFound - } - uinfo.SessionsCap = Uint32(b.Get([]byte("SessionsCap"))) - uinfo.UpRate = int64(Uint64(b.Get([]byte("UpRate")))) - uinfo.DownRate = int64(Uint64(b.Get([]byte("DownRate")))) - uinfo.UpCredit = int64(Uint64(b.Get([]byte("UpCredit")))) // reee brackets - uinfo.DownCredit = int64(Uint64(b.Get([]byte("DownCredit")))) - uinfo.ExpiryTime = int64(Uint64(b.Get([]byte("ExpiryTime")))) - return nil - }) - if err != nil { - up.activeUsersM.Unlock() - return nil, err - } - u := MakeUser(up, &uinfo) - up.activeUsers[arrUID] = u - up.activeUsersM.Unlock() - return u, nil -} - -func (up *Userpanel) updateDBEntryUint32(UID []byte, key string, value uint32) error { - err := up.db.Update(func(tx *bolt.Tx) error { - b := tx.Bucket(UID) - if b == nil { - return ErrUserNotFound - } - if err := b.Put([]byte(key), u32ToB(value)); err != nil { - return err - } - return nil - }) - return err -} - -func (up *Userpanel) updateDBEntryInt64(UID []byte, key string, value int64) error { - err := up.db.Update(func(tx *bolt.Tx) error { - b := tx.Bucket(UID) - if b == nil { - return ErrUserNotFound - } - if err := b.Put([]byte(key), i64ToB(value)); err != nil { - return err - } - return nil - }) - return err -} - -// This is used when all sessions of a user close -func (up *Userpanel) delActiveUser(UID []byte) { - var arrUID [16]byte - copy(arrUID[:], UID) - up.activeUsersM.Lock() - delete(up.activeUsers, arrUID) - up.activeUsersM.Unlock() -} - -func (up *Userpanel) getActiveUser(UID []byte) *User { - var arrUID [16]byte - copy(arrUID[:], UID) - up.activeUsersM.RLock() - ret := up.activeUsers[arrUID] - up.activeUsersM.RUnlock() - return ret -} - -// below are remote control utilised functions - -func (up *Userpanel) listActiveUsers() [][]byte { - var ret [][]byte - up.activeUsersM.RLock() - for _, u := range up.activeUsers { - ret = append(ret, u.UID) - } - up.activeUsersM.RUnlock() - return ret -} - -func (up *Userpanel) listAllUsers() []UserInfo { - var ret []UserInfo - up.db.View(func(tx *bolt.Tx) error { - tx.ForEach(func(UID []byte, b *bolt.Bucket) error { - // if we want to avoid writing every single key out, - // we would have to either make UserInfo a map, - // or use reflect. - // neither is convinient - var uinfo UserInfo - uinfo.UID = UID - uinfo.SessionsCap = Uint32(b.Get([]byte("SessionsCap"))) - uinfo.UpRate = int64(Uint64(b.Get([]byte("UpRate")))) - uinfo.DownRate = int64(Uint64(b.Get([]byte("DownRate")))) - uinfo.UpCredit = int64(Uint64(b.Get([]byte("UpCredit")))) - uinfo.DownCredit = int64(Uint64(b.Get([]byte("DownCredit")))) - uinfo.ExpiryTime = int64(Uint64(b.Get([]byte("ExpiryTime")))) - ret = append(ret, uinfo) - return nil - }) - return nil - }) - return ret -} - -func (up *Userpanel) getUserInfo(UID []byte) (UserInfo, error) { - var uinfo UserInfo - err := up.db.View(func(tx *bolt.Tx) error { - b := tx.Bucket(UID) - if b == nil { - return ErrUserNotFound - } - uinfo.UID = UID - uinfo.SessionsCap = Uint32(b.Get([]byte("SessionsCap"))) - uinfo.UpRate = int64(Uint64(b.Get([]byte("UpRate")))) - uinfo.DownRate = int64(Uint64(b.Get([]byte("DownRate")))) - uinfo.UpCredit = int64(Uint64(b.Get([]byte("UpCredit")))) - uinfo.DownCredit = int64(Uint64(b.Get([]byte("DownCredit")))) - uinfo.ExpiryTime = int64(Uint64(b.Get([]byte("ExpiryTime")))) - return nil - }) - return uinfo, err -} - -// In boltdb, the value argument for bucket.Put has to be valid for the duration -// of the transaction. -// This basically means that you cannot reuse a byte slice for two different keys -// in a transaction. So we need to allocate a fresh byte slice for each value -func u32ToB(value uint32) []byte { - quad := make([]byte, 4) - PutUint32(quad, value) - return quad -} - -func i64ToB(value int64) []byte { - oct := make([]byte, 8) - PutUint64(oct, uint64(value)) - return oct -} - -func (up *Userpanel) addNewUser(uinfo UserInfo) error { - err := up.db.Update(func(tx *bolt.Tx) error { - b, err := tx.CreateBucket(uinfo.UID[:]) - if err != nil { - return err - } - if err = b.Put([]byte("SessionsCap"), u32ToB(uinfo.SessionsCap)); err != nil { - return err - } - if err = b.Put([]byte("UpRate"), i64ToB(uinfo.UpRate)); err != nil { - return err - } - if err = b.Put([]byte("DownRate"), i64ToB(uinfo.DownRate)); err != nil { - return err - } - if err = b.Put([]byte("UpCredit"), i64ToB(uinfo.UpCredit)); err != nil { - return err - } - if err = b.Put([]byte("DownCredit"), i64ToB(uinfo.DownCredit)); err != nil { - return err - } - if err = b.Put([]byte("ExpiryTime"), i64ToB(uinfo.ExpiryTime)); err != nil { - return err - } - return nil - }) - return err -} - -func (up *Userpanel) delUser(UID []byte) error { - err := up.backupDB(strconv.FormatInt(time.Now().Unix(), 10) + "_pre_del_" + base64.StdEncoding.EncodeToString(UID) + ".bak") - if err != nil { - return err - } - err = up.db.Update(func(tx *bolt.Tx) error { - return tx.DeleteBucket(UID) - }) - return err -} - -func (up *Userpanel) syncMemFromDB(UID []byte) error { - var uinfo UserInfo - err := up.db.View(func(tx *bolt.Tx) error { - b := tx.Bucket(UID) - if b == nil { - return ErrUserNotFound - } - uinfo.UID = UID - uinfo.SessionsCap = Uint32(b.Get([]byte("SessionsCap"))) - uinfo.UpRate = int64(Uint64(b.Get([]byte("UpRate")))) - uinfo.DownRate = int64(Uint64(b.Get([]byte("DownRate")))) - uinfo.UpCredit = int64(Uint64(b.Get([]byte("UpCredit")))) - uinfo.DownCredit = int64(Uint64(b.Get([]byte("DownCredit")))) - uinfo.ExpiryTime = int64(Uint64(b.Get([]byte("ExpiryTime")))) - return nil - }) - if err != nil { - return err - } - - u := up.getActiveUser(UID) - if u == nil { - return ErrUserNotActive - } - u.updateInfo(uinfo) - return nil -} - -// the following functions will update the db entries first, then if the -// user is active, it will update it in memory. - -func (up *Userpanel) setSessionsCap(UID []byte, cap uint32) error { - err := up.updateDBEntryUint32(UID, "SessionsCap", cap) - if err != nil { - return err - } - u := up.getActiveUser(UID) - if u == nil { - return nil - } - u.setSessionsCap(cap) - return nil -} - -func (up *Userpanel) setUpRate(UID []byte, rate int64) error { - err := up.updateDBEntryInt64(UID, "UpRate", rate) - if err != nil { - return err - } - u := up.getActiveUser(UID) - if u == nil { - return nil - } - u.setUpRate(rate) - return nil -} -func (up *Userpanel) setDownRate(UID []byte, rate int64) error { - err := up.updateDBEntryInt64(UID, "DownRate", rate) - if err != nil { - return err - } - u := up.getActiveUser(UID) - if u == nil { - return nil - } - u.setDownRate(rate) - return nil -} -func (up *Userpanel) setUpCredit(UID []byte, n int64) error { - err := up.updateDBEntryInt64(UID, "UpCredit", n) - if err != nil { - return err - } - u := up.getActiveUser(UID) - if u == nil { - return nil - } - u.setUpCredit(n) - return nil -} -func (up *Userpanel) setDownCredit(UID []byte, n int64) error { - err := up.updateDBEntryInt64(UID, "DownCredit", n) - if err != nil { - return err - } - u := up.getActiveUser(UID) - if u == nil { - return nil - } - u.setDownCredit(n) - return nil -} - -func (up *Userpanel) setExpiryTime(UID []byte, time int64) error { - err := up.updateDBEntryInt64(UID, "ExpiryTime", time) - if err != nil { - return err - } - u := up.getActiveUser(UID) - if u == nil { - return nil - } - u.setExpiryTime(time) - return nil -} - -func (up *Userpanel) addUpCredit(UID []byte, delta int64) error { - err := up.db.Update(func(tx *bolt.Tx) error { - b := tx.Bucket(UID) - if b == nil { - return ErrUserNotFound - } - old := b.Get([]byte("UpCredit")) - new := int64(Uint64(old)) + delta - if err := b.Put([]byte("UpCredit"), i64ToB(new)); err != nil { - return err - } - return nil - }) - if err != nil { - return err - } - u := up.getActiveUser(UID) - if u == nil { - return nil - } - u.addUpCredit(delta) - return nil -} - -func (up *Userpanel) addDownCredit(UID []byte, delta int64) error { - err := up.db.Update(func(tx *bolt.Tx) error { - b := tx.Bucket(UID) - if b == nil { - return ErrUserNotFound - } - old := b.Get([]byte("DownCredit")) - new := int64(Uint64(old)) + delta - if err := b.Put([]byte("DownCredit"), i64ToB(new)); err != nil { - return err - } - return nil - }) - if err != nil { - return err - } - u := up.getActiveUser(UID) - if u == nil { - return nil - } - u.addDownCredit(delta) - return nil -} diff --git a/internal/server/userpanel.go b/internal/server/userpanel.go new file mode 100644 index 0000000..d4c888e --- /dev/null +++ b/internal/server/userpanel.go @@ -0,0 +1,138 @@ +package server + +import ( + "sync" + "sync/atomic" + "time" + + mux "github.com/cbeuw/Cloak/internal/multiplex" +) + +type userPanel struct { + manager UserManager + + activeUsersM sync.RWMutex + activeUsers map[[16]byte]*ActiveUser + usageUpdateQueueM sync.Mutex + usageUpdateQueue map[[16]byte]*usagePair +} + +func MakeUserPanel(manager UserManager) *userPanel { + ret := &userPanel{ + manager: manager, + activeUsers: make(map[[16]byte]*ActiveUser), + usageUpdateQueue: make(map[[16]byte]*usagePair), + } + go ret.regularQueueUpload() + return ret +} + +func (panel *userPanel) GetUser(UID []byte) (*ActiveUser, error) { + panel.activeUsersM.Lock() + var arrUID [16]byte + copy(arrUID[:], UID) + if user, ok := panel.activeUsers[arrUID]; ok { + panel.activeUsersM.Unlock() + return user, nil + } + + upRate, downRate, err := panel.manager.authenticateUser(UID) + if err != nil { + panel.activeUsersM.Unlock() + return nil, err + } + var upUsage, downUsage int64 + valve := mux.MakeValve(upRate, downRate, &upUsage, &downUsage) + user := &ActiveUser{ + panel: panel, + valve: valve, + sessions: make(map[uint32]*mux.Session), + } + copy(user.arrUID[:], UID) + panel.activeUsers[user.arrUID] = user + panel.activeUsersM.Unlock() + return user, nil +} + +func (panel *userPanel) isActive(UID []byte) bool { + var arrUID [16]byte + copy(arrUID[:], UID) + panel.activeUsersM.RLock() + _, ok := panel.activeUsers[arrUID] + panel.activeUsersM.RUnlock() + return ok +} + +type usagePair struct { + up *int64 + down *int64 +} + +func (panel *userPanel) updateUsageQueue() { + panel.activeUsersM.Lock() + panel.usageUpdateQueueM.Lock() + for _, user := range panel.activeUsers { + upIncured, downIncured := user.valve.Nullify() + if usage, ok := panel.usageUpdateQueue[user.arrUID]; ok { + atomic.AddInt64(usage.up, upIncured) + atomic.AddInt64(usage.down, downIncured) + } else { + // if the user hasn't been added to the queue + usage = &usagePair{&upIncured, &downIncured} + panel.usageUpdateQueue[user.arrUID] = usage + } + } + panel.activeUsersM.Unlock() + panel.usageUpdateQueueM.Unlock() +} + +func (panel *userPanel) updateUsageQueueForOne(user *ActiveUser) { + // used when one particular user deactivates + upIncured, downIncured := user.valve.Nullify() + panel.usageUpdateQueueM.Lock() + if usage, ok := panel.usageUpdateQueue[user.arrUID]; ok { + atomic.AddInt64(usage.up, upIncured) + atomic.AddInt64(usage.down, downIncured) + } else { + usage = &usagePair{&upIncured, &downIncured} + panel.usageUpdateQueue[user.arrUID] = usage + } + panel.usageUpdateQueueM.Unlock() + +} + +func (panel *userPanel) commitUpdate() { + panel.usageUpdateQueueM.Lock() + statuses := make([]statusUpdate, 0, len(panel.usageUpdateQueue)) + for arrUID, usage := range panel.usageUpdateQueue { + panel.activeUsersM.RLock() + user := panel.activeUsers[arrUID] + panel.activeUsersM.RUnlock() + var numSession int + if user != nil { + numSession = user.NumSession() + } + status := statusUpdate{ + UID: arrUID[:], + active: panel.isActive(arrUID[:]), + numSession: numSession, + upUsage: *usage.up, + downUsage: *usage.down, + timestamp: time.Now().Unix(), + } + statuses = append(statuses, status) + } + panel.manager.uploadStatus(statuses) + panel.usageUpdateQueue = make(map[[16]byte]*usagePair) + panel.usageUpdateQueueM.Unlock() +} + +func (panel *userPanel) regularQueueUpload() { + for { + time.Sleep(1 * time.Minute) + go func() { + panel.updateUsageQueue() + panel.commitUpdate() + }() + } +}