diff --git a/cmd/ck-client/ck-client.go b/cmd/ck-client/ck-client.go index e495c8c..893895e 100644 --- a/cmd/ck-client/ck-client.go +++ b/cmd/ck-client/ck-client.go @@ -10,6 +10,7 @@ import ( "math/rand" "net" "os" + "sync" "time" "github.com/cbeuw/Cloak/internal/client" @@ -150,15 +151,21 @@ func main() { deobfs := util.MakeDeobfs(sta.UID) sesh := mux.MakeSession(0, valve, obfs, deobfs, util.ReadTLS) + var wg sync.WaitGroup // TODO: use sync group for i := 0; i < sta.NumConn; i++ { - conn, err := makeRemoteConn(sta) - if err != nil { - log.Printf("Failed to establish new connections to remote: %v\n", err) - return - } - sesh.AddConnection(conn) + wg.Add(1) + go func() { + conn, err := makeRemoteConn(sta) + if err != nil { + log.Printf("Failed to establish new connections to remote: %v\n", err) + return + } + sesh.AddConnection(conn) + wg.Done() + }() } + wg.Wait() listener, err := net.Listen("tcp", sta.SS_LOCAL_HOST+":"+sta.SS_LOCAL_PORT) if err != nil { diff --git a/cmd/ck-server/ck-server.go b/cmd/ck-server/ck-server.go index a4decb1..d8f8820 100644 --- a/cmd/ck-server/ck-server.go +++ b/cmd/ck-server/ck-server.go @@ -81,7 +81,6 @@ func dispatchConnection(conn net.Conn, sta *server.State) { var arrUID [32]byte copy(arrUID[:], UID) user, err := sta.Userpanel.GetAndActivateUser(arrUID) - log.Printf("UID: %x\n", UID) if err != nil { log.Printf("+1 unauthorised user from %v, uid: %x\n", conn.RemoteAddr(), UID) goWeb(data) @@ -106,27 +105,31 @@ func dispatchConnection(conn net.Conn, sta *server.State) { } } - // FIXME: the following code should not be executed for every single remote connection - sesh := user.GetOrCreateSession(sessionID, util.MakeObfs(UID), util.MakeDeobfs(UID), util.ReadTLS) - sesh.AddConnection(conn) - for { - newStream, err := sesh.AcceptStream() - if err != nil { - log.Printf("Failed to get new stream: %v", err) - if err == mux.ErrBrokenSession { - user.DelSession(sessionID) - return - } else { + if sesh, existing := user.GetOrCreateSession(sessionID, util.MakeObfs(UID), util.MakeDeobfs(UID), util.ReadTLS); existing { + sesh.AddConnection(conn) + return + } else { + log.Printf("UID: %x\n", UID) + sesh.AddConnection(conn) + for { + newStream, err := sesh.AcceptStream() + if err != nil { + log.Printf("Failed to get new stream: %v", err) + if err == mux.ErrBrokenSession { + user.DelSession(sessionID) + return + } else { + continue + } + } + ssConn, err := net.Dial("tcp", sta.SS_LOCAL_HOST+":"+sta.SS_LOCAL_PORT) + if err != nil { + log.Printf("Failed to connect to ssserver: %v", err) continue } + go pipe(ssConn, newStream) + go pipe(newStream, ssConn) } - ssConn, err := net.Dial("tcp", sta.SS_LOCAL_HOST+":"+sta.SS_LOCAL_PORT) - if err != nil { - log.Printf("Failed to connect to ssserver: %v", err) - continue - } - go pipe(ssConn, newStream) - go pipe(newStream, ssConn) } } diff --git a/internal/server/usermanager/user.go b/internal/server/usermanager/user.go index 7b0e72e..4c92b9e 100644 --- a/internal/server/usermanager/user.go +++ b/internal/server/usermanager/user.go @@ -1,11 +1,12 @@ package usermanager import ( - mux "github.com/cbeuw/Cloak/internal/multiplex" "log" "net" "sync" "sync/atomic" + + mux "github.com/cbeuw/Cloak/internal/multiplex" ) /* @@ -18,7 +19,7 @@ type userParams struct { } */ -type user struct { +type User struct { up *Userpanel uid [32]byte @@ -31,9 +32,9 @@ type user struct { sessions map[uint32]*mux.Session } -func MakeUser(up *Userpanel, uid [32]byte, sessionsCap uint32, upRate, downRate, upCredit, downCredit int64) *user { +func MakeUser(up *Userpanel, uid [32]byte, sessionsCap uint32, upRate, downRate, upCredit, downCredit int64) *User { valve := mux.MakeValve(upRate, downRate, upCredit, downCredit) - u := &user{ + u := &User{ up: up, uid: uid, valve: valve, @@ -43,27 +44,23 @@ func MakeUser(up *Userpanel, uid [32]byte, sessionsCap uint32, upRate, downRate, return u } -func (u *user) setSessionsCap(cap uint32) { +func (u *User) setSessionsCap(cap uint32) { atomic.StoreUint32(&u.sessionsCap, cap) } -func (u *user) GetSession(sessionID uint32) *mux.Session { +func (u *User) GetSession(sessionID uint32) *mux.Session { u.sessionsM.RLock() defer u.sessionsM.RUnlock() - if sesh, ok := u.sessions[sessionID]; ok { - return sesh - } else { - return nil - } + return u.sessions[sessionID] } -func (u *user) PutSession(sessionID uint32, sesh *mux.Session) { +func (u *User) PutSession(sessionID uint32, sesh *mux.Session) { u.sessionsM.Lock() u.sessions[sessionID] = sesh u.sessionsM.Unlock() } -func (u *user) DelSession(sessionID uint32) { +func (u *User) DelSession(sessionID uint32) { u.sessionsM.Lock() delete(u.sessions, sessionID) if len(u.sessions) == 0 { @@ -74,13 +71,15 @@ func (u *user) DelSession(sessionID uint32) { u.sessionsM.Unlock() } -func (u *user) GetOrCreateSession(sessionID uint32, obfs func(*mux.Frame) []byte, deobfs func([]byte) *mux.Frame, obfsedRead func(net.Conn, []byte) (int, error)) (sesh *mux.Session) { - log.Printf("getting sessionID %v\n", sessionID) - if sesh = u.GetSession(sessionID); sesh != nil { - return +func (u *User) GetOrCreateSession(sessionID uint32, obfs func(*mux.Frame) []byte, deobfs func([]byte) *mux.Frame, obfsedRead func(net.Conn, []byte) (int, error)) (sesh *mux.Session, existing bool) { + u.sessionsM.Lock() + defer u.sessionsM.Unlock() + if sesh = u.sessions[sessionID]; sesh != nil { + return sesh, true } else { + log.Printf("Creating session %v\n", sessionID) sesh = mux.MakeSession(sessionID, u.valve, obfs, deobfs, obfsedRead) - u.PutSession(sessionID, sesh) - return + u.sessions[sessionID] = sesh + return sesh, false } } diff --git a/internal/server/usermanager/userpanel.go b/internal/server/usermanager/userpanel.go index 8408c37..67378aa 100644 --- a/internal/server/usermanager/userpanel.go +++ b/internal/server/usermanager/userpanel.go @@ -3,15 +3,16 @@ package usermanager import ( "encoding/binary" "errors" - "github.com/boltdb/bolt" "sync" + + "github.com/boltdb/bolt" ) type Userpanel struct { db *bolt.DB activeUsersM sync.RWMutex - activeUsers map[[32]byte]*user + activeUsers map[[32]byte]*User } func MakeUserpanel(dbPath string) (*Userpanel, error) { @@ -21,7 +22,7 @@ func MakeUserpanel(dbPath string) (*Userpanel, error) { } up := &Userpanel{ db: db, - activeUsers: make(map[[32]byte]*user), + activeUsers: make(map[[32]byte]*User), } return up, nil } @@ -30,13 +31,12 @@ var ErrUserNotFound = errors.New("User does not exist in memory or db") // GetUser is used to retrieve a user if s/he is active, or to retrieve the user's infor // from the db and mark it as an active user -func (up *Userpanel) GetAndActivateUser(UID [32]byte) (*user, error) { - up.activeUsersM.RLock() +func (up *Userpanel) GetAndActivateUser(UID [32]byte) (*User, error) { + up.activeUsersM.Lock() + defer up.activeUsersM.Unlock() if user, ok := up.activeUsers[UID]; ok { - up.activeUsersM.RUnlock() return user, nil } - up.activeUsersM.RUnlock() var sessionsCap uint32 var upRate, downRate, upCredit, downCredit int64 @@ -57,9 +57,7 @@ func (up *Userpanel) GetAndActivateUser(UID [32]byte) (*user, error) { } // TODO: put all of these parameters in a struct instead u := MakeUser(up, UID, sessionsCap, upRate, downRate, upCredit, downCredit) - up.activeUsersM.Lock() up.activeUsers[UID] = u - up.activeUsersM.Unlock() return u, nil } @@ -136,7 +134,7 @@ func (up *Userpanel) delActiveUser(UID [32]byte) { up.activeUsersM.Unlock() } -func (up *Userpanel) getActiveUser(UID [32]byte) *user { +func (up *Userpanel) getActiveUser(UID [32]byte) *User { up.activeUsersM.RLock() defer up.activeUsersM.RUnlock() return up.activeUsers[UID]