mirror of https://github.com/cbeuw/Cloak
QOS and user managing, bug fixes
This commit is contained in:
parent
6a6b293164
commit
3534d05055
|
|
@ -7,6 +7,7 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
|
@ -60,7 +61,7 @@ func makeRemoteConn(sta *client.State) (net.Conn, error) {
|
|||
// Three discarded messages: ServerHello, ChangeCipherSpec and Finished
|
||||
discardBuf := make([]byte, 1024)
|
||||
for c := 0; c < 3; c++ {
|
||||
_, err = util.ReadTillDrain(remoteConn, discardBuf)
|
||||
_, err = util.ReadTLS(remoteConn, discardBuf)
|
||||
if err != nil {
|
||||
log.Printf("Reading discarded message %v: %v\n", c, err)
|
||||
return nil, err
|
||||
|
|
@ -122,9 +123,13 @@ func main() {
|
|||
log.Printf("Starting standalone mode. Listening for ss on %v:%v\n", localHost, localPort)
|
||||
}
|
||||
|
||||
opaque := time.Now().UnixNano()
|
||||
// sessionID is usergenerated. There shouldn't be a security concern because the scope of
|
||||
// sessionID is limited to its UID.
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
sessionID := rand.Uint32()
|
||||
|
||||
// opaque is used to generate the padding of session ticket
|
||||
sta := client.InitState(localHost, localPort, remoteHost, remotePort, time.Now, opaque)
|
||||
sta := client.InitState(localHost, localPort, remoteHost, remotePort, time.Now, sessionID)
|
||||
err := sta.ParseConfig(pluginOpts)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
|
@ -140,19 +145,19 @@ func main() {
|
|||
log.Fatal("TicketTimeHint cannot be empty or 0")
|
||||
}
|
||||
|
||||
obfs := util.MakeObfs(sta.SID)
|
||||
deobfs := util.MakeDeobfs(sta.SID)
|
||||
sesh := mux.MakeSession(0, 1e9, 1e9, obfs, deobfs, util.ReadTillDrain)
|
||||
valve := mux.MakeValve(1e9, 1e9, 1e9, 1e9)
|
||||
obfs := util.MakeObfs(sta.UID)
|
||||
deobfs := util.MakeDeobfs(sta.UID)
|
||||
sesh := mux.MakeSession(0, valve, obfs, deobfs, util.ReadTLS)
|
||||
|
||||
// TODO: use sync group
|
||||
for i := 0; i < sta.NumConn; i++ {
|
||||
go func() {
|
||||
conn, err := makeRemoteConn(sta)
|
||||
if err != nil {
|
||||
log.Printf("Failed to establish new connections to remote: %v\n", err)
|
||||
return
|
||||
}
|
||||
sesh.AddConnection(conn)
|
||||
}()
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", sta.SS_LOCAL_HOST+":"+sta.SS_LOCAL_PORT)
|
||||
|
|
@ -175,8 +180,12 @@ func main() {
|
|||
stream, err := sesh.OpenStream()
|
||||
if err != nil {
|
||||
ssConn.Close()
|
||||
return
|
||||
}
|
||||
_, err = stream.Write(data[:i])
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
stream.Write(data[:i])
|
||||
go pipe(ssConn, stream)
|
||||
pipe(stream, ssConn)
|
||||
}()
|
||||
|
|
|
|||
|
|
@ -1,15 +1,16 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
//"net/http"
|
||||
//_ "net/http/pprof"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
//"runtime"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
|
@ -70,14 +71,21 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
|
|||
return
|
||||
}
|
||||
|
||||
isSS, SID := server.TouchStone(ch, sta)
|
||||
isSS, UID, sessionID := server.TouchStone(ch, sta)
|
||||
if !isSS {
|
||||
log.Printf("+1 non SS TLS traffic from %v\n", conn.RemoteAddr())
|
||||
goWeb(data)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: verify SID
|
||||
var arrUID [32]byte
|
||||
copy(arrUID[:], UID)
|
||||
user, err := sta.Userpanel.GetAndActivateUser(arrUID)
|
||||
log.Printf("UID: %x\n", UID)
|
||||
if err != nil {
|
||||
log.Printf("+1 unauthorised user from %v, uid: %x\n", conn.RemoteAddr(), UID)
|
||||
goWeb(data)
|
||||
}
|
||||
|
||||
reply := server.ComposeReply(ch)
|
||||
_, err = conn.Write(reply)
|
||||
|
|
@ -90,7 +98,7 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
|
|||
// Two discarded messages: ChangeCipherSpec and Finished
|
||||
discardBuf := make([]byte, 1024)
|
||||
for c := 0; c < 2; c++ {
|
||||
_, err = util.ReadTillDrain(conn, discardBuf)
|
||||
_, err = util.ReadTLS(conn, discardBuf)
|
||||
if err != nil {
|
||||
log.Printf("Reading discarded message %v: %v\n", c, err)
|
||||
go conn.Close()
|
||||
|
|
@ -98,22 +106,15 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
|
|||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
var arrSID [32]byte
|
||||
copy(arrSID[:], SID)
|
||||
var sesh *mux.Session
|
||||
if sesh = sta.GetSession(arrSID); sesh == nil {
|
||||
sesh = mux.MakeSession(0, 1e9, 1e9, util.MakeObfs(SID), util.MakeDeobfs(SID), util.ReadTillDrain)
|
||||
sta.PutSession(arrSID, sesh)
|
||||
}
|
||||
// FIXME: the following code should not be executed for every single remote connection
|
||||
sesh := user.GetOrCreateSession(sessionID, util.MakeObfs(UID), util.MakeDeobfs(UID), util.ReadTLS)
|
||||
sesh.AddConnection(conn)
|
||||
go func() {
|
||||
for {
|
||||
newStream, err := sesh.AcceptStream()
|
||||
if err != nil {
|
||||
log.Printf("Failed to get new stream: %v", err)
|
||||
if err == mux.ErrBrokenSession {
|
||||
sta.DelSession(arrSID)
|
||||
user.DelSession(sessionID)
|
||||
return
|
||||
} else {
|
||||
continue
|
||||
|
|
@ -127,16 +128,14 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
|
|||
go pipe(ssConn, newStream)
|
||||
go pipe(newStream, ssConn)
|
||||
}
|
||||
}()
|
||||
}()
|
||||
|
||||
}
|
||||
|
||||
func main() {
|
||||
//runtime.SetBlockProfileRate(5)
|
||||
//go func() {
|
||||
// log.Println(http.ListenAndServe("0.0.0.0:8001", nil))
|
||||
//}()
|
||||
runtime.SetBlockProfileRate(5)
|
||||
go func() {
|
||||
log.Println(http.ListenAndServe("0.0.0.0:8001", nil))
|
||||
}()
|
||||
// Should be 127.0.0.1 to listen to ss-server on this machine
|
||||
var localHost string
|
||||
// server_port in ss config, same as remotePort in plugin mode
|
||||
|
|
@ -181,7 +180,13 @@ func main() {
|
|||
localPort = strings.Split(*localAddr, ":")[1]
|
||||
log.Printf("Starting standalone mode, listening on %v:%v to ss at %v:%v\n", remoteHost, remotePort, localHost, localPort)
|
||||
}
|
||||
sta := server.InitState(localHost, localPort, remoteHost, remotePort, time.Now)
|
||||
sta, _ := server.InitState(localHost, localPort, remoteHost, remotePort, time.Now, "userinfo.db")
|
||||
|
||||
//debug
|
||||
var arrUID [32]byte
|
||||
UID, _ := hex.DecodeString("50d858e0985ecc7f60418aaf0cc5ab587f42c2570a884095a9e8ccacd0f6545c")
|
||||
copy(arrUID[:], UID)
|
||||
sta.Userpanel.AddNewUser(arrUID, 10, 1e12, 1e12, 1e12, 1e12)
|
||||
err := sta.ParseConfig(pluginOpts)
|
||||
if err != nil {
|
||||
log.Fatalf("Configuration file error: %v", err)
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ func MakeRandomField(sta *State) []byte {
|
|||
rdm := make([]byte, 16)
|
||||
io.ReadFull(rand.Reader, rdm)
|
||||
preHash := make([]byte, 56)
|
||||
copy(preHash[0:32], sta.SID)
|
||||
copy(preHash[0:32], sta.UID)
|
||||
copy(preHash[32:40], t)
|
||||
copy(preHash[40:56], rdm)
|
||||
h := sha256.New()
|
||||
|
|
@ -33,9 +33,9 @@ func MakeRandomField(sta *State) []byte {
|
|||
}
|
||||
|
||||
func MakeSessionTicket(sta *State) []byte {
|
||||
// sessionTicket: [marshalled ephemeral pub key 32 bytes][encrypted SID 32 bytes][padding 128 bytes]
|
||||
// sessionTicket: [marshalled ephemeral pub key 32 bytes][encrypted UID+sessionID 36 bytes][padding 124 bytes]
|
||||
// The first 16 bytes of the marshalled ephemeral public key is used as the IV
|
||||
// for encrypting the SID
|
||||
// for encrypting the UID
|
||||
tthInterval := sta.Now().Unix() / int64(sta.TicketTimeHint)
|
||||
ec := ecdh.NewCurve25519ECDH()
|
||||
ephKP := sta.getKeyPair(tthInterval)
|
||||
|
|
@ -50,8 +50,21 @@ func MakeSessionTicket(sta *State) []byte {
|
|||
ticket := make([]byte, 192)
|
||||
copy(ticket[0:32], ec.Marshal(ephKP.PublicKey))
|
||||
key, _ := ec.GenerateSharedSecret(ephKP.PrivateKey, sta.staticPub)
|
||||
cipherSID := util.AESEncrypt(ticket[0:16], key, sta.SID)
|
||||
copy(ticket[32:64], cipherSID)
|
||||
copy(ticket[64:192], util.PsudoRandBytes(128, tthInterval+sta.opaque))
|
||||
plainUIDsID := make([]byte, 36)
|
||||
copy(plainUIDsID, sta.UID)
|
||||
binary.BigEndian.PutUint32(plainUIDsID[32:36], sta.sessionID)
|
||||
cipherUIDsID := util.AESEncrypt(ticket[0:16], key, plainUIDsID)
|
||||
copy(ticket[32:68], cipherUIDsID)
|
||||
// The purpose of adding sessionID is that, the generated padding of sessionTicket needs to be unpredictable.
|
||||
// As shown in auth.go, the padding is generated by a psudo random generator. The seed
|
||||
// needs to be the same for each TicketTimeHint interval. However the value of epoch/TicketTimeHint
|
||||
// is public knowledge, so is the psudo random algorithm used by math/rand. Therefore not only
|
||||
// can the firewall tell that the padding is generated in this specific way, this padding is identical
|
||||
// for all ckclients in the same TicketTimeHint interval. This will expose us.
|
||||
//
|
||||
// With the sessionID value generated at startup of ckclient and used as a part of the seed, the
|
||||
// sessionTicket is still identical for each TicketTimeHint interval, but others won't be able to know
|
||||
// how it was generated. It will also be different for each client.
|
||||
copy(ticket[68:192], util.PsudoRandBytes(124, tthInterval+int64(sta.sessionID)))
|
||||
return ticket
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,8 +29,8 @@ type State struct {
|
|||
SS_REMOTE_PORT string
|
||||
|
||||
Now func() time.Time
|
||||
opaque int64
|
||||
SID []byte
|
||||
sessionID uint32
|
||||
UID []byte
|
||||
staticPub crypto.PublicKey
|
||||
keyPairsM sync.RWMutex
|
||||
keyPairs map[int64]*keyPair
|
||||
|
|
@ -41,14 +41,14 @@ type State struct {
|
|||
NumConn int
|
||||
}
|
||||
|
||||
func InitState(localHost, localPort, remoteHost, remotePort string, nowFunc func() time.Time, opaque int64) *State {
|
||||
func InitState(localHost, localPort, remoteHost, remotePort string, nowFunc func() time.Time, sessionID uint32) *State {
|
||||
ret := &State{
|
||||
SS_LOCAL_HOST: localHost,
|
||||
SS_LOCAL_PORT: localPort,
|
||||
SS_REMOTE_HOST: remoteHost,
|
||||
SS_REMOTE_PORT: remotePort,
|
||||
Now: nowFunc,
|
||||
opaque: opaque,
|
||||
sessionID: sessionID,
|
||||
}
|
||||
ret.keyPairs = make(map[int64]*keyPair)
|
||||
return ret
|
||||
|
|
@ -56,6 +56,7 @@ func InitState(localHost, localPort, remoteHost, remotePort string, nowFunc func
|
|||
|
||||
// semi-colon separated value. This is for Android plugin options
|
||||
func ssvToJson(ssv string) (ret []byte) {
|
||||
// TODO: base64 encoded data has =. How to escape?
|
||||
unescape := func(s string) string {
|
||||
r := strings.Replace(s, "\\\\", "\\", -1)
|
||||
r = strings.Replace(r, "\\=", "=", -1)
|
||||
|
|
@ -104,16 +105,16 @@ func (sta *State) ParseConfig(conf string) (err error) {
|
|||
sta.TicketTimeHint = preParse.TicketTimeHint
|
||||
sta.MaskBrowser = preParse.MaskBrowser
|
||||
sta.NumConn = preParse.NumConn
|
||||
sid, pub, err := parseKey(preParse.Key)
|
||||
uid, pub, err := parseKey(preParse.Key)
|
||||
if err != nil {
|
||||
return errors.New("Failed to parse Key: " + err.Error())
|
||||
}
|
||||
sta.SID = sid
|
||||
sta.UID = uid
|
||||
sta.staticPub = pub
|
||||
return nil
|
||||
}
|
||||
|
||||
// Structure: [SID 32 bytes][marshalled public key 32 bytes]
|
||||
// Structure: [UID 32 bytes][marshalled public key 32 bytes]
|
||||
func parseKey(b64 string) ([]byte, crypto.PublicKey, error) {
|
||||
b, err := base64.StdEncoding.DecodeString(b64)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -15,8 +15,7 @@ import (
|
|||
// make sure packets arrive in order.
|
||||
//
|
||||
// Cloak packets will have a 32-bit sequence number on them, so we know in which order
|
||||
// they should be sent to shadowsocks. In the case that the packets arrive out-of-order,
|
||||
// the code in this file provides buffering and sorting.
|
||||
// they should be sent to shadowsocks. The code in this file provides buffering and sorting.
|
||||
//
|
||||
// Similar to TCP, the next seq number after 2^32-1 is 0. This is called wrap around.
|
||||
//
|
||||
|
|
@ -54,6 +53,12 @@ func (sh *sorterHeap) Pop() interface{} {
|
|||
return x
|
||||
}
|
||||
|
||||
func (s *Stream) writeNewFrame(f *Frame) {
|
||||
s.newFrameCh <- f
|
||||
}
|
||||
|
||||
// recvNewFrame is a forever running loop which receives frames unordered,
|
||||
// cache and order them and send them into sortedBufCh
|
||||
func (s *Stream) recvNewFrame() {
|
||||
for {
|
||||
var f *Frame
|
||||
|
|
@ -69,7 +74,7 @@ func (s *Stream) recvNewFrame() {
|
|||
|
||||
if len(s.sh) == 0 && f.Seq == s.nextRecvSeq {
|
||||
if f.Closing == 1 {
|
||||
s.passiveClose()
|
||||
s.sortedBufCh <- []byte{}
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -115,7 +120,7 @@ func (s *Stream) recvNewFrame() {
|
|||
frame := heap.Pop(&s.sh).(*frameNode).frame
|
||||
|
||||
if frame.Closing == 1 {
|
||||
s.passiveClose()
|
||||
s.sortedBufCh <- []byte{}
|
||||
return
|
||||
}
|
||||
payload := frame.Payload
|
||||
|
|
|
|||
|
|
@ -0,0 +1,58 @@
|
|||
package multiplex
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/juju/ratelimit"
|
||||
)
|
||||
|
||||
// Valve needs to be universal, across all sessions that belong to a user
|
||||
// gabe please don't sue
|
||||
type Valve struct {
|
||||
// traffic directions from the server's perspective are refered
|
||||
// exclusively as rx and tx.
|
||||
// rx is from client to server, tx is from server to client
|
||||
// DO NOT use terms up or down as this is used in usermanager
|
||||
// for bandwidth limiting
|
||||
rxtb atomic.Value // *ratelimit.Bucket
|
||||
txtb atomic.Value // *ratelimit.Bucket
|
||||
|
||||
rxCredit int64
|
||||
txCredit int64
|
||||
}
|
||||
|
||||
func MakeValve(rxRate, txRate, rxCredit, txCredit int64) *Valve {
|
||||
v := &Valve{
|
||||
rxCredit: rxCredit,
|
||||
txCredit: txCredit,
|
||||
}
|
||||
v.SetRxRate(rxRate)
|
||||
v.SetTxRate(txRate)
|
||||
return v
|
||||
}
|
||||
|
||||
func (v *Valve) SetRxRate(rate int64) {
|
||||
v.rxtb.Store(ratelimit.NewBucketWithRate(float64(rate), rate))
|
||||
}
|
||||
|
||||
func (v *Valve) SetTxRate(rate int64) {
|
||||
v.txtb.Store(ratelimit.NewBucketWithRate(float64(rate), rate))
|
||||
}
|
||||
|
||||
func (v *Valve) rxWait(n int) {
|
||||
v.rxtb.Load().(*ratelimit.Bucket).Wait(int64(n))
|
||||
}
|
||||
|
||||
func (v *Valve) txWait(n int) {
|
||||
v.txtb.Load().(*ratelimit.Bucket).Wait(int64(n))
|
||||
}
|
||||
|
||||
// n can be negative
|
||||
func (v *Valve) AddRxCredit(n int64) int64 {
|
||||
return atomic.AddInt64(&v.rxCredit, n)
|
||||
}
|
||||
|
||||
// n can be negative
|
||||
func (v *Valve) AddTxCredit(n int64) int64 {
|
||||
return atomic.AddInt64(&v.txCredit, n)
|
||||
}
|
||||
|
|
@ -2,6 +2,7 @@ package multiplex
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
|
@ -16,14 +17,14 @@ var ErrBrokenSession = errors.New("broken session")
|
|||
var errRepeatSessionClosing = errors.New("trying to close a closed session")
|
||||
|
||||
type Session struct {
|
||||
id int
|
||||
id uint32 // This field isn't acutally used
|
||||
|
||||
// Used in Stream.Write. Add multiplexing headers, encrypt and add TLS header
|
||||
obfs func(*Frame) []byte
|
||||
// Remove TLS header, decrypt and unmarshall multiplexing headers
|
||||
deobfs func([]byte) *Frame
|
||||
// This is supposed to read one TLS message, the same as GoQuiet's ReadTillDrain
|
||||
obfsedReader func(net.Conn, []byte) (int, error)
|
||||
obfsedRead func(net.Conn, []byte) (int, error)
|
||||
|
||||
// atomic
|
||||
nextStreamID uint32
|
||||
|
|
@ -37,24 +38,25 @@ type Session struct {
|
|||
// For accepting new streams
|
||||
acceptCh chan *Stream
|
||||
|
||||
// TODO: use sync.Once for this
|
||||
closingM sync.Mutex
|
||||
die chan struct{}
|
||||
closing bool
|
||||
}
|
||||
|
||||
// 1 conn is needed to make a session
|
||||
func MakeSession(id int, uprate, downrate float64, obfs func(*Frame) []byte, deobfs func([]byte) *Frame, obfsedReader func(net.Conn, []byte) (int, error)) *Session {
|
||||
func MakeSession(id uint32, valve *Valve, obfs func(*Frame) []byte, deobfs func([]byte) *Frame, obfsedRead func(net.Conn, []byte) (int, error)) *Session {
|
||||
sesh := &Session{
|
||||
id: id,
|
||||
obfs: obfs,
|
||||
deobfs: deobfs,
|
||||
obfsedReader: obfsedReader,
|
||||
obfsedRead: obfsedRead,
|
||||
nextStreamID: 1,
|
||||
streams: make(map[uint32]*Stream),
|
||||
acceptCh: make(chan *Stream, acceptBacklog),
|
||||
die: make(chan struct{}),
|
||||
}
|
||||
sesh.sb = makeSwitchboard(sesh, uprate, downrate)
|
||||
sesh.sb = makeSwitchboard(sesh, valve)
|
||||
return sesh
|
||||
}
|
||||
|
||||
|
|
@ -63,12 +65,18 @@ func (sesh *Session) AddConnection(conn net.Conn) {
|
|||
}
|
||||
|
||||
func (sesh *Session) OpenStream() (*Stream, error) {
|
||||
id := atomic.AddUint32(&sesh.nextStreamID, 1)
|
||||
id -= 1 // Because atomic.AddUint32 returns the value after incrementation
|
||||
select {
|
||||
case <-sesh.die:
|
||||
return nil, ErrBrokenSession
|
||||
default:
|
||||
}
|
||||
id := atomic.AddUint32(&sesh.nextStreamID, 1) - 1
|
||||
// Because atomic.AddUint32 returns the value after incrementation
|
||||
stream := makeStream(id, sesh)
|
||||
sesh.streamsM.Lock()
|
||||
sesh.streams[id] = stream
|
||||
sesh.streamsM.Unlock()
|
||||
log.Printf("Opening stream %v\n", id)
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
|
|
@ -108,6 +116,7 @@ func (sesh *Session) addStream(id uint32) *Stream {
|
|||
sesh.streams[id] = stream
|
||||
sesh.streamsM.Unlock()
|
||||
sesh.acceptCh <- stream
|
||||
log.Printf("Adding stream %v\n", id)
|
||||
return stream
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ type Stream struct {
|
|||
// atomic
|
||||
nextSendSeq uint32
|
||||
|
||||
closingM sync.Mutex
|
||||
closingM sync.RWMutex
|
||||
// close(die) is used to notify different goroutines that this stream is closing
|
||||
die chan struct{}
|
||||
// to prevent closing a closed channel
|
||||
|
|
@ -45,7 +45,7 @@ func makeStream(id uint32, sesh *Session) *Stream {
|
|||
die: make(chan struct{}),
|
||||
sh: []*frameNode{},
|
||||
newFrameCh: make(chan *Frame, 1024),
|
||||
sortedBufCh: make(chan []byte, 4096),
|
||||
sortedBufCh: make(chan []byte, 1024),
|
||||
}
|
||||
go stream.recvNewFrame()
|
||||
return stream
|
||||
|
|
@ -64,6 +64,10 @@ func (stream *Stream) Read(buf []byte) (n int, err error) {
|
|||
case <-stream.die:
|
||||
return 0, errBrokenStream
|
||||
case data := <-stream.sortedBufCh:
|
||||
if len(data) == 0 {
|
||||
stream.passiveClose()
|
||||
return 0, errBrokenStream
|
||||
}
|
||||
if len(buf) < len(data) {
|
||||
log.Println(len(data))
|
||||
return 0, errors.New("buf too small")
|
||||
|
|
@ -75,6 +79,13 @@ func (stream *Stream) Read(buf []byte) (n int, err error) {
|
|||
}
|
||||
|
||||
func (stream *Stream) Write(in []byte) (n int, err error) {
|
||||
// RWMutex used here isn't really for RW.
|
||||
// we use it to exploit the fact that RLock doesn't create contention.
|
||||
// The use of RWMutex is so that the stream will not actively close
|
||||
// in the middle of the execution of Write. This may cause the closing frame
|
||||
// to be sent before the data frame and cause loss of packet.
|
||||
stream.closingM.RLock()
|
||||
defer stream.closingM.RUnlock()
|
||||
select {
|
||||
case <-stream.die:
|
||||
return 0, errBrokenStream
|
||||
|
|
@ -83,13 +94,11 @@ func (stream *Stream) Write(in []byte) (n int, err error) {
|
|||
|
||||
f := &Frame{
|
||||
StreamID: stream.id,
|
||||
Seq: atomic.LoadUint32(&stream.nextSendSeq),
|
||||
Seq: atomic.AddUint32(&stream.nextSendSeq, 1) - 1,
|
||||
Closing: 0,
|
||||
Payload: in,
|
||||
}
|
||||
|
||||
atomic.AddUint32(&stream.nextSendSeq, 1)
|
||||
|
||||
tlsRecord := stream.session.obfs(f)
|
||||
n, err = stream.session.sb.send(tlsRecord)
|
||||
|
||||
|
|
@ -97,9 +106,7 @@ func (stream *Stream) Write(in []byte) (n int, err error) {
|
|||
|
||||
}
|
||||
|
||||
// only close locally. Used when the stream close is notified by the remote
|
||||
func (stream *Stream) passiveClose() error {
|
||||
|
||||
func (stream *Stream) shutdown() error {
|
||||
// Lock here because closing a closed channel causes panic
|
||||
stream.closingM.Lock()
|
||||
defer stream.closingM.Unlock()
|
||||
|
|
@ -108,29 +115,36 @@ func (stream *Stream) passiveClose() error {
|
|||
}
|
||||
stream.closing = true
|
||||
close(stream.die)
|
||||
return nil
|
||||
}
|
||||
|
||||
// only close locally. Used when the stream close is notified by the remote
|
||||
func (stream *Stream) passiveClose() error {
|
||||
err := stream.shutdown()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stream.session.delStream(stream.id)
|
||||
log.Printf("%v passive closing\n", stream.id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// active close. Close locally and tell the remote that this stream is being closed
|
||||
func (stream *Stream) Close() error {
|
||||
|
||||
// Lock here because closing a closed channel causes panic
|
||||
stream.closingM.Lock()
|
||||
defer stream.closingM.Unlock()
|
||||
if stream.closing {
|
||||
return errRepeatStreamClosing
|
||||
err := stream.shutdown()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stream.closing = true
|
||||
close(stream.die)
|
||||
|
||||
// Notify remote that this stream is closed
|
||||
prand.Seed(int64(stream.id))
|
||||
padLen := int(math.Floor(prand.Float64()*200 + 300))
|
||||
pad := make([]byte, padLen)
|
||||
prand.Read(pad)
|
||||
f := &Frame{
|
||||
StreamID: stream.id,
|
||||
Seq: atomic.LoadUint32(&stream.nextSendSeq),
|
||||
Seq: atomic.AddUint32(&stream.nextSendSeq, 1) - 1,
|
||||
Closing: 1,
|
||||
Payload: pad,
|
||||
}
|
||||
|
|
@ -138,20 +152,12 @@ func (stream *Stream) Close() error {
|
|||
stream.session.sb.send(tlsRecord)
|
||||
|
||||
stream.session.delStream(stream.id)
|
||||
log.Printf("%v actively closed\n", stream.id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Same as Close() but no call to session.delStream.
|
||||
// This is called in session.Close() to avoid mutex deadlock
|
||||
func (stream *Stream) closeNoDelMap() error {
|
||||
|
||||
// Lock here because closing a closed channel causes panic
|
||||
stream.closingM.Lock()
|
||||
defer stream.closingM.Unlock()
|
||||
if stream.closing {
|
||||
return errRepeatStreamClosing
|
||||
}
|
||||
stream.closing = true
|
||||
close(stream.die)
|
||||
return nil
|
||||
return stream.shutdown()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,20 +6,34 @@ import (
|
|||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/juju/ratelimit"
|
||||
)
|
||||
|
||||
// switchboard is responsible for keeping the reference of TLS connections between client and server
|
||||
type switchboard struct {
|
||||
session *Session
|
||||
|
||||
wtb *ratelimit.Bucket
|
||||
rtb *ratelimit.Bucket
|
||||
*Valve
|
||||
|
||||
optimum atomic.Value
|
||||
// optimum is the connEnclave with the smallest sendQueue
|
||||
optimum atomic.Value // *connEnclave
|
||||
cesM sync.RWMutex
|
||||
ces []*connEnclave
|
||||
|
||||
//debug
|
||||
hM sync.Mutex
|
||||
used map[uint32]bool
|
||||
}
|
||||
|
||||
func (sb *switchboard) getOptimum() *connEnclave {
|
||||
if i := sb.optimum.Load(); i == nil {
|
||||
return nil
|
||||
} else {
|
||||
return i.(*connEnclave)
|
||||
}
|
||||
}
|
||||
|
||||
func (sb *switchboard) setOptimum(ce *connEnclave) {
|
||||
sb.optimum.Store(ce)
|
||||
}
|
||||
|
||||
// Some data comes from a Stream to be sent through one of the many
|
||||
|
|
@ -27,45 +41,51 @@ type switchboard struct {
|
|||
//
|
||||
// In this case, we pick the remoteConn that has about the smallest sendQueue.
|
||||
type connEnclave struct {
|
||||
sb *switchboard
|
||||
remoteConn net.Conn
|
||||
sendQueue uint32
|
||||
}
|
||||
|
||||
// It takes at least 1 conn to start a switchboard
|
||||
// TODO: does it really?
|
||||
func makeSwitchboard(sesh *Session, uprate, downrate float64) *switchboard {
|
||||
func makeSwitchboard(sesh *Session, valve *Valve) *switchboard {
|
||||
// rates are uint64 because in the usermanager we want the bandwidth to be atomically
|
||||
// operated (so that the bandwidth can change on the fly).
|
||||
sb := &switchboard{
|
||||
session: sesh,
|
||||
wtb: ratelimit.NewBucketWithRate(uprate, int64(uprate)),
|
||||
rtb: ratelimit.NewBucketWithRate(downrate, int64(downrate)),
|
||||
Valve: valve,
|
||||
ces: []*connEnclave{},
|
||||
used: make(map[uint32]bool),
|
||||
}
|
||||
return sb
|
||||
}
|
||||
|
||||
var errNilOptimum error = errors.New("The optimal connection is nil")
|
||||
|
||||
var ErrNoRxCredit error = errors.New("No Rx credit is left")
|
||||
var ErrNoTxCredit error = errors.New("No Tx credit is left")
|
||||
|
||||
func (sb *switchboard) send(data []byte) (int, error) {
|
||||
ce := sb.optimum.Load().(*connEnclave)
|
||||
ce := sb.getOptimum()
|
||||
if ce == nil {
|
||||
return 0, errNilOptimum
|
||||
}
|
||||
sb.wtb.Wait(int64(len(data)))
|
||||
atomic.AddUint32(&ce.sendQueue, uint32(len(data)))
|
||||
go sb.updateOptimum()
|
||||
n, err := ce.remoteConn.Write(data)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return n, err
|
||||
// TODO
|
||||
}
|
||||
if sb.AddTxCredit(-int64(n)) < 0 {
|
||||
log.Println(ErrNoTxCredit)
|
||||
defer sb.session.Close()
|
||||
return n, ErrNoTxCredit
|
||||
}
|
||||
atomic.AddUint32(&ce.sendQueue, ^uint32(n-1))
|
||||
go sb.updateOptimum()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (sb *switchboard) updateOptimum() {
|
||||
currentOpti := sb.optimum.Load().(*connEnclave)
|
||||
currentOpti := sb.getOptimum()
|
||||
currentOptiQ := atomic.LoadUint32(¤tOpti.sendQueue)
|
||||
sb.cesM.RLock()
|
||||
for _, ce := range sb.ces {
|
||||
|
|
@ -76,20 +96,18 @@ func (sb *switchboard) updateOptimum() {
|
|||
}
|
||||
}
|
||||
sb.cesM.RUnlock()
|
||||
sb.optimum.Store(currentOpti)
|
||||
sb.setOptimum(currentOpti)
|
||||
}
|
||||
|
||||
func (sb *switchboard) addConn(conn net.Conn) {
|
||||
|
||||
newCe := &connEnclave{
|
||||
sb: sb,
|
||||
remoteConn: conn,
|
||||
sendQueue: 0,
|
||||
}
|
||||
sb.cesM.Lock()
|
||||
sb.ces = append(sb.ces, newCe)
|
||||
sb.cesM.Unlock()
|
||||
sb.optimum.Store(newCe)
|
||||
sb.setOptimum(newCe)
|
||||
go sb.deplex(newCe)
|
||||
}
|
||||
|
||||
|
|
@ -101,10 +119,10 @@ func (sb *switchboard) removeConn(closing *connEnclave) {
|
|||
break
|
||||
}
|
||||
}
|
||||
sb.cesM.Unlock()
|
||||
if len(sb.ces) == 0 {
|
||||
sb.session.Close()
|
||||
}
|
||||
sb.cesM.Unlock()
|
||||
}
|
||||
|
||||
func (sb *switchboard) shutdown() {
|
||||
|
|
@ -118,19 +136,40 @@ func (sb *switchboard) shutdown() {
|
|||
func (sb *switchboard) deplex(ce *connEnclave) {
|
||||
buf := make([]byte, 20480)
|
||||
for {
|
||||
i, err := sb.session.obfsedReader(ce.remoteConn, buf)
|
||||
sb.rtb.Wait(int64(i))
|
||||
n, err := sb.session.obfsedRead(ce.remoteConn, buf)
|
||||
sb.rxWait(n)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
go ce.remoteConn.Close()
|
||||
sb.removeConn(ce)
|
||||
return
|
||||
}
|
||||
frame := sb.session.deobfs(buf[:i])
|
||||
if sb.AddRxCredit(-int64(n)) < 0 {
|
||||
log.Println(ErrNoRxCredit)
|
||||
sb.session.Close()
|
||||
return
|
||||
}
|
||||
frame := sb.session.deobfs(buf[:n])
|
||||
|
||||
//debug
|
||||
|
||||
var stream *Stream
|
||||
if stream = sb.session.getStream(frame.StreamID); stream == nil {
|
||||
if frame.Closing == 1 {
|
||||
// if the frame is telling us to close a closed stream
|
||||
// (this happens when ss-server and ss-local closes the stream
|
||||
// simutaneously), we don't do anything
|
||||
continue
|
||||
}
|
||||
//debug
|
||||
sb.hM.Lock()
|
||||
if sb.used[frame.StreamID] {
|
||||
log.Printf("%v lost!\n", frame.StreamID)
|
||||
}
|
||||
sb.used[frame.StreamID] = true
|
||||
sb.hM.Unlock()
|
||||
stream = sb.session.addStream(frame.StreamID)
|
||||
}
|
||||
stream.newFrameCh <- frame
|
||||
stream.writeNewFrame(frame)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,54 +11,54 @@ import (
|
|||
ecdh "github.com/cbeuw/go-ecdh"
|
||||
)
|
||||
|
||||
// input ticket, return SID
|
||||
func decryptSessionTicket(staticPv crypto.PrivateKey, ticket []byte) ([]byte, error) {
|
||||
// input ticket, return UID
|
||||
func decryptSessionTicket(staticPv crypto.PrivateKey, ticket []byte) ([]byte, uint32, error) {
|
||||
ec := ecdh.NewCurve25519ECDH()
|
||||
ephPub, _ := ec.Unmarshal(ticket[0:32])
|
||||
key, err := ec.GenerateSharedSecret(staticPv, ephPub)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
SID := util.AESDecrypt(ticket[0:16], key, ticket[32:64])
|
||||
return SID, nil
|
||||
UIDsID := util.AESDecrypt(ticket[0:16], key, ticket[32:68])
|
||||
sessionID := binary.BigEndian.Uint32(UIDsID[32:36])
|
||||
return UIDsID[0:32], sessionID, nil
|
||||
}
|
||||
|
||||
func validateRandom(random []byte, SID []byte, time int64) bool {
|
||||
func validateRandom(random []byte, UID []byte, time int64) bool {
|
||||
t := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(t, uint64(time/(12*60*60)))
|
||||
rdm := random[0:16]
|
||||
preHash := make([]byte, 56)
|
||||
copy(preHash[0:32], SID)
|
||||
copy(preHash[0:32], UID)
|
||||
copy(preHash[32:40], t)
|
||||
copy(preHash[40:56], rdm)
|
||||
h := sha256.New()
|
||||
h.Write(preHash)
|
||||
return bytes.Equal(h.Sum(nil)[0:16], random[16:32])
|
||||
}
|
||||
func TouchStone(ch *ClientHello, sta *State) (bool, []byte) {
|
||||
func TouchStone(ch *ClientHello, sta *State) (isSS bool, UID []byte, sessionID uint32) {
|
||||
var random [32]byte
|
||||
copy(random[:], ch.random)
|
||||
used := sta.getUsedRandom(random)
|
||||
if used != 0 {
|
||||
log.Println("Replay! Duplicate random")
|
||||
return false, nil
|
||||
return false, nil, 0
|
||||
}
|
||||
sta.putUsedRandom(random)
|
||||
|
||||
ticket := ch.extensions[[2]byte{0x00, 0x23}]
|
||||
if len(ticket) < 64 {
|
||||
return false, nil
|
||||
return false, nil, 0
|
||||
}
|
||||
SID, err := decryptSessionTicket(sta.staticPv, ticket)
|
||||
UID, sessionID, err := decryptSessionTicket(sta.staticPv, ticket)
|
||||
if err != nil {
|
||||
log.Printf("ts: %v\n", err)
|
||||
return false, nil
|
||||
return false, nil, 0
|
||||
}
|
||||
log.Printf("SID: %x\n", SID)
|
||||
isSS := validateRandom(ch.random, SID, sta.Now().Unix())
|
||||
isSS = validateRandom(ch.random, UID, sta.Now().Unix())
|
||||
if !isSS {
|
||||
return false, nil
|
||||
return false, nil, 0
|
||||
}
|
||||
|
||||
return true, SID
|
||||
return
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
mux "github.com/cbeuw/Cloak/internal/multiplex"
|
||||
"github.com/cbeuw/Cloak/internal/server/usermanager"
|
||||
)
|
||||
|
||||
type rawConfig struct {
|
||||
|
|
@ -31,25 +31,28 @@ type State struct {
|
|||
|
||||
Now func() time.Time
|
||||
staticPv crypto.PrivateKey
|
||||
Userpanel *usermanager.Userpanel
|
||||
usedRandomM sync.RWMutex
|
||||
usedRandom map[[32]byte]int
|
||||
sessionsM sync.RWMutex
|
||||
sessions map[[32]byte]*mux.Session
|
||||
|
||||
WebServerAddr string
|
||||
}
|
||||
|
||||
func InitState(localHost, localPort, remoteHost, remotePort string, nowFunc func() time.Time) *State {
|
||||
func InitState(localHost, localPort, remoteHost, remotePort string, nowFunc func() time.Time, dbPath string) (*State, error) {
|
||||
up, err := usermanager.MakeUserpanel(dbPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret := &State{
|
||||
SS_LOCAL_HOST: localHost,
|
||||
SS_LOCAL_PORT: localPort,
|
||||
SS_REMOTE_HOST: remoteHost,
|
||||
SS_REMOTE_PORT: remotePort,
|
||||
Now: nowFunc,
|
||||
Userpanel: up,
|
||||
}
|
||||
ret.usedRandom = make(map[[32]byte]int)
|
||||
ret.sessions = make(map[[32]byte]*mux.Session)
|
||||
return ret
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// semi-colon separated value.
|
||||
|
|
@ -115,28 +118,6 @@ func (sta *State) ParseConfig(conf string) (err error) {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (sta *State) GetSession(SID [32]byte) *mux.Session {
|
||||
sta.sessionsM.RLock()
|
||||
defer sta.sessionsM.RUnlock()
|
||||
if sesh, ok := sta.sessions[SID]; ok {
|
||||
return sesh
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (sta *State) PutSession(SID [32]byte, sesh *mux.Session) {
|
||||
sta.sessionsM.Lock()
|
||||
sta.sessions[SID] = sesh
|
||||
sta.sessionsM.Unlock()
|
||||
}
|
||||
|
||||
func (sta *State) DelSession(SID [32]byte) {
|
||||
sta.sessionsM.Lock()
|
||||
delete(sta.sessions, SID)
|
||||
sta.sessionsM.Unlock()
|
||||
}
|
||||
|
||||
func (sta *State) getUsedRandom(random [32]byte) int {
|
||||
sta.usedRandomM.Lock()
|
||||
defer sta.usedRandomM.Unlock()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,86 @@
|
|||
package usermanager
|
||||
|
||||
import (
|
||||
mux "github.com/cbeuw/Cloak/internal/multiplex"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
/*
|
||||
type userParams struct {
|
||||
sessionsCap uint32
|
||||
upRate int64
|
||||
downRate int64
|
||||
upCredit int64
|
||||
downCredit int64
|
||||
}
|
||||
*/
|
||||
|
||||
type user struct {
|
||||
up *Userpanel
|
||||
|
||||
uid [32]byte
|
||||
|
||||
sessionsCap uint32 //userParams
|
||||
|
||||
valve *mux.Valve
|
||||
|
||||
sessionsM sync.RWMutex
|
||||
sessions map[uint32]*mux.Session
|
||||
}
|
||||
|
||||
func MakeUser(up *Userpanel, uid [32]byte, sessionsCap uint32, upRate, downRate, upCredit, downCredit int64) *user {
|
||||
valve := mux.MakeValve(upRate, downRate, upCredit, downCredit)
|
||||
u := &user{
|
||||
up: up,
|
||||
uid: uid,
|
||||
valve: valve,
|
||||
sessionsCap: sessionsCap,
|
||||
sessions: make(map[uint32]*mux.Session),
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *user) setSessionsCap(cap uint32) {
|
||||
atomic.StoreUint32(&u.sessionsCap, cap)
|
||||
}
|
||||
|
||||
func (u *user) GetSession(sessionID uint32) *mux.Session {
|
||||
u.sessionsM.RLock()
|
||||
defer u.sessionsM.RUnlock()
|
||||
if sesh, ok := u.sessions[sessionID]; ok {
|
||||
return sesh
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (u *user) PutSession(sessionID uint32, sesh *mux.Session) {
|
||||
u.sessionsM.Lock()
|
||||
u.sessions[sessionID] = sesh
|
||||
u.sessionsM.Unlock()
|
||||
}
|
||||
|
||||
func (u *user) DelSession(sessionID uint32) {
|
||||
u.sessionsM.Lock()
|
||||
delete(u.sessions, sessionID)
|
||||
if len(u.sessions) == 0 {
|
||||
u.sessionsM.Unlock()
|
||||
u.up.delActiveUser(u.uid)
|
||||
return
|
||||
}
|
||||
u.sessionsM.Unlock()
|
||||
}
|
||||
|
||||
func (u *user) GetOrCreateSession(sessionID uint32, obfs func(*mux.Frame) []byte, deobfs func([]byte) *mux.Frame, obfsedRead func(net.Conn, []byte) (int, error)) (sesh *mux.Session) {
|
||||
log.Printf("getting sessionID %v\n", sessionID)
|
||||
if sesh = u.GetSession(sessionID); sesh != nil {
|
||||
return
|
||||
} else {
|
||||
sesh = mux.MakeSession(sessionID, u.valve, obfs, deobfs, obfsedRead)
|
||||
u.PutSession(sessionID, sesh)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,151 @@
|
|||
package usermanager
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"github.com/boltdb/bolt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Userpanel struct {
|
||||
db *bolt.DB
|
||||
|
||||
activeUsersM sync.RWMutex
|
||||
activeUsers map[[32]byte]*user
|
||||
}
|
||||
|
||||
func MakeUserpanel(dbPath string) (*Userpanel, error) {
|
||||
db, err := bolt.Open(dbPath, 0600, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
up := &Userpanel{
|
||||
db: db,
|
||||
activeUsers: make(map[[32]byte]*user),
|
||||
}
|
||||
return up, nil
|
||||
}
|
||||
|
||||
var ErrUserNotFound = errors.New("User does not exist in memory or db")
|
||||
|
||||
// GetUser is used to retrieve a user if s/he is active, or to retrieve the user's infor
|
||||
// from the db and mark it as an active user
|
||||
func (up *Userpanel) GetAndActivateUser(UID [32]byte) (*user, error) {
|
||||
up.activeUsersM.RLock()
|
||||
if user, ok := up.activeUsers[UID]; ok {
|
||||
up.activeUsersM.RUnlock()
|
||||
return user, nil
|
||||
}
|
||||
up.activeUsersM.RUnlock()
|
||||
|
||||
var sessionsCap uint32
|
||||
var upRate, downRate, upCredit, downCredit int64
|
||||
err := up.db.View(func(tx *bolt.Tx) error {
|
||||
b := tx.Bucket(UID[:])
|
||||
if b == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
sessionsCap = binary.BigEndian.Uint32(b.Get([]byte("sessionsCap")))
|
||||
upRate = int64(binary.BigEndian.Uint64(b.Get([]byte("upRate"))))
|
||||
downRate = int64(binary.BigEndian.Uint64(b.Get([]byte("downRate"))))
|
||||
upCredit = int64(binary.BigEndian.Uint64(b.Get([]byte("upCredit")))) // reee brackets
|
||||
downCredit = int64(binary.BigEndian.Uint64(b.Get([]byte("downCredit"))))
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// TODO: put all of these parameters in a struct instead
|
||||
u := MakeUser(up, UID, sessionsCap, upRate, downRate, upCredit, downCredit)
|
||||
up.activeUsersM.Lock()
|
||||
up.activeUsers[UID] = u
|
||||
up.activeUsersM.Unlock()
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (up *Userpanel) AddNewUser(UID [32]byte, sessionsCap uint32, upRate, downRate, upCredit, downCredit int64) error {
|
||||
err := up.db.Update(func(tx *bolt.Tx) error {
|
||||
b, err := tx.CreateBucket(UID[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// FIXME: obnoxious code
|
||||
quad := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(quad, sessionsCap)
|
||||
if err = b.Put([]byte("sessionsCap"), quad); err != nil {
|
||||
return err
|
||||
}
|
||||
oct := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(oct, uint64(upRate))
|
||||
if err = b.Put([]byte("upRate"), oct); err != nil {
|
||||
return err
|
||||
}
|
||||
binary.BigEndian.PutUint64(oct, uint64(downRate))
|
||||
if err = b.Put([]byte("downRate"), oct); err != nil {
|
||||
return err
|
||||
}
|
||||
binary.BigEndian.PutUint64(oct, uint64(upCredit))
|
||||
if err = b.Put([]byte("upCredit"), oct); err != nil {
|
||||
return err
|
||||
}
|
||||
binary.BigEndian.PutUint64(oct, uint64(downCredit))
|
||||
if err = b.Put([]byte("downCredit"), oct); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (up *Userpanel) updateDBEntryUint32(UID [32]byte, key string, value uint32) error {
|
||||
err := up.db.Update(func(tx *bolt.Tx) error {
|
||||
b := tx.Bucket(UID[:])
|
||||
if b == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
quad := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(quad, value)
|
||||
if err := b.Put([]byte(key), quad); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (up *Userpanel) updateDBEntryInt64(UID [32]byte, key string, value int64) error {
|
||||
err := up.db.Update(func(tx *bolt.Tx) error {
|
||||
b := tx.Bucket(UID[:])
|
||||
if b == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
oct := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(oct, uint64(value))
|
||||
if err := b.Put([]byte(key), oct); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// This is used when all sessions of a user close
|
||||
func (up *Userpanel) delActiveUser(UID [32]byte) {
|
||||
up.activeUsersM.Lock()
|
||||
delete(up.activeUsers, UID)
|
||||
up.activeUsersM.Unlock()
|
||||
}
|
||||
|
||||
func (up *Userpanel) getActiveUser(UID [32]byte) *user {
|
||||
up.activeUsersM.RLock()
|
||||
defer up.activeUsersM.RUnlock()
|
||||
return up.activeUsers[UID]
|
||||
}
|
||||
|
||||
func (up *Userpanel) SetSessionsCap(UID [32]byte, newSessionsCap uint32) error {
|
||||
if u := up.getActiveUser(UID); u != nil {
|
||||
u.setSessionsCap(newSessionsCap)
|
||||
}
|
||||
err := up.updateDBEntryUint32(UID, "sessionsCap", newSessionsCap)
|
||||
return err
|
||||
}
|
||||
|
|
@ -9,12 +9,13 @@ import (
|
|||
|
||||
// For each frame, the three parts of the header is xored with three keys.
|
||||
// The keys are generated from the SID and the payload of the frame.
|
||||
func genXorKeys(SID []byte, data []byte) (i uint32, ii uint32, iii uint32) {
|
||||
// FIXME: this code will panic if len(data)<18.
|
||||
func genXorKeys(secret []byte, data []byte) (i uint32, ii uint32, iii uint32) {
|
||||
h := xxhash.New32()
|
||||
ret := make([]uint32, 3)
|
||||
preHash := make([]byte, 16)
|
||||
for j := 0; j < 3; j++ {
|
||||
copy(preHash[0:10], SID[j*10:j*10+10])
|
||||
copy(preHash[0:10], secret[j*10:j*10+10])
|
||||
copy(preHash[10:16], data[j*6:j*6+6])
|
||||
h.Write(preHash)
|
||||
ret[j] = h.Sum32()
|
||||
|
|
|
|||
|
|
@ -43,14 +43,14 @@ func BtoInt(b []byte) int {
|
|||
|
||||
// PsudoRandBytes returns a byte slice filled with psudorandom bytes generated by the seed
|
||||
func PsudoRandBytes(length int, seed int64) []byte {
|
||||
prand.Seed(seed)
|
||||
r := prand.New(prand.NewSource(seed))
|
||||
ret := make([]byte, length)
|
||||
prand.Read(ret)
|
||||
r.Read(ret)
|
||||
return ret
|
||||
}
|
||||
|
||||
// ReadTillDrain reads TLS data according to its record layer
|
||||
func ReadTillDrain(conn net.Conn, buffer []byte) (n int, err error) {
|
||||
// ReadTLS reads TLS data according to its record layer
|
||||
func ReadTLS(conn net.Conn, buffer []byte) (n int, err error) {
|
||||
// TCP is a stream. Multiple TLS messages can arrive at the same time,
|
||||
// a single message can also be segmented due to MTU of the IP layer.
|
||||
// This function guareentees a single TLS message to be read and everything
|
||||
|
|
|
|||
Loading…
Reference in New Issue