From cb672a99dec217d108bd8655de60ef1bbb5d427d Mon Sep 17 00:00:00 2001 From: Qian Wang Date: Mon, 12 Aug 2019 00:22:15 +0100 Subject: [PATCH] Refactor session configuration --- cmd/ck-client/ck-client.go | 8 +++- cmd/ck-server/ck-server.go | 14 +++++- internal/multiplex/session.go | 33 +++++++++----- internal/multiplex/session_test.go | 15 +++++-- internal/multiplex/stream_test.go | 61 +++++++++++++++++++++++++- internal/multiplex/switchboard.go | 2 +- internal/multiplex/switchboard_test.go | 8 +++- internal/server/activeuser.go | 6 +-- internal/server/activeuser_test.go | 13 ++---- 9 files changed, 126 insertions(+), 34 deletions(-) diff --git a/cmd/ck-client/ck-client.go b/cmd/ck-client/ck-client.go index 1789519..3e5d04d 100644 --- a/cmd/ck-client/ck-client.go +++ b/cmd/ck-client/ck-client.go @@ -73,7 +73,13 @@ func makeSession(sta *client.State) *mux.Session { if err != nil { log.Fatal(err) } - sesh := mux.MakeSession(sta.SessionID, mux.UNLIMITED_VALVE, obfuscator, util.ReadTLS) + + seshConfig := &mux.SessionConfig{ + Obfuscator: obfuscator, + Valve: nil, + UnitRead: util.ReadTLS, + } + sesh := mux.MakeSession(sta.SessionID, seshConfig) for i := 0; i < sta.NumConn; i++ { conn := <-connsCh diff --git a/cmd/ck-server/ck-server.go b/cmd/ck-server/ck-server.go index 9f14b24..1fafda4 100644 --- a/cmd/ck-server/ck-server.go +++ b/cmd/ck-server/ck-server.go @@ -84,7 +84,12 @@ func dispatchConnection(conn net.Conn, sta *server.State) { return } log.Trace("finished handshake") - sesh := mux.MakeSession(0, mux.UNLIMITED_VALVE, obfuscator, util.ReadTLS) + seshConfig := &mux.SessionConfig{ + Obfuscator: obfuscator, + Valve: nil, + UnitRead: util.ReadTLS, + } + sesh := mux.MakeSession(0, seshConfig) sesh.AddConnection(conn) //TODO: Router could be nil in cnc mode log.WithField("remoteAddr", conn.RemoteAddr()).Info("New admin session") @@ -111,7 +116,12 @@ func dispatchConnection(conn net.Conn, sta *server.State) { return } - sesh, existing, err := user.GetSession(sessionID, obfuscator, util.ReadTLS) + seshConfig := &mux.SessionConfig{ + Obfuscator: obfuscator, + Valve: nil, + UnitRead: util.ReadTLS, + } + sesh, existing, err := user.GetSession(sessionID, seshConfig) if err != nil { user.DeleteSession(sessionID, "") log.Error(err) diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index 025cd70..dcc36c3 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -25,13 +25,19 @@ type Obfuscator struct { SessionKey []byte } +type SessionConfig struct { + *Obfuscator + + Valve + + // This is supposed to read one TLS message, the same as GoQuiet's ReadTillDrain + UnitRead func(net.Conn, []byte) (int, error) +} + type Session struct { id uint32 - *Obfuscator - - // This is supposed to read one TLS message, the same as GoQuiet's ReadTillDrain - unitRead func(net.Conn, []byte) (int, error) + *SessionConfig // atomic nextStreamID uint32 @@ -52,17 +58,20 @@ type Session struct { terminalMsg atomic.Value } -func MakeSession(id uint32, valve Valve, obfuscator *Obfuscator, unitReader func(net.Conn, []byte) (int, error)) *Session { +func MakeSession(id uint32, config *SessionConfig) *Session { sesh := &Session{ - id: id, - unitRead: unitReader, - nextStreamID: 1, - Obfuscator: obfuscator, - streams: make(map[uint32]*Stream), - acceptCh: make(chan *Stream, acceptBacklog), + id: id, + SessionConfig: config, + nextStreamID: 1, + streams: make(map[uint32]*Stream), + acceptCh: make(chan *Stream, acceptBacklog), } sesh.addrs.Store([]net.Addr{nil, nil}) - sesh.sb = makeSwitchboard(sesh, valve) + + if config.Valve == nil { + config.Valve = UNLIMITED_VALVE + } + sesh.sb = makeSwitchboard(sesh, config.Valve) go sesh.timeoutAfter(30 * time.Second) return sesh } diff --git a/internal/multiplex/session_test.go b/internal/multiplex/session_test.go index b7b947a..3086a3d 100644 --- a/internal/multiplex/session_test.go +++ b/internal/multiplex/session_test.go @@ -6,6 +6,12 @@ import ( "testing" ) +var seshConfig = &SessionConfig{ + Obfuscator: nil, + Valve: nil, + UnitRead: util.ReadTLS, +} + func BenchmarkRecvDataFromRemote(b *testing.B) { testPayload := make([]byte, 1024) rand.Read(testPayload) @@ -22,7 +28,8 @@ func BenchmarkRecvDataFromRemote(b *testing.B) { b.Run("plain", func(b *testing.B) { obfuscator, _ := GenerateObfs(0x00, sessionKey) - sesh := MakeSession(0, UNLIMITED_VALVE, obfuscator, util.ReadTLS) + seshConfig.Obfuscator = obfuscator + sesh := MakeSession(0, seshConfig) n, _ := sesh.Obfs(f, obfsBuf) b.ResetTimer() @@ -34,7 +41,8 @@ func BenchmarkRecvDataFromRemote(b *testing.B) { b.Run("aes-gcm", func(b *testing.B) { obfuscator, _ := GenerateObfs(0x01, sessionKey) - sesh := MakeSession(0, UNLIMITED_VALVE, obfuscator, util.ReadTLS) + seshConfig.Obfuscator = obfuscator + sesh := MakeSession(0, seshConfig) n, _ := sesh.Obfs(f, obfsBuf) b.ResetTimer() @@ -46,7 +54,8 @@ func BenchmarkRecvDataFromRemote(b *testing.B) { b.Run("chacha20-poly1305", func(b *testing.B) { obfuscator, _ := GenerateObfs(0x02, sessionKey) - sesh := MakeSession(0, UNLIMITED_VALVE, obfuscator, util.ReadTLS) + seshConfig.Obfuscator = obfuscator + sesh := MakeSession(0, seshConfig) n, _ := sesh.Obfs(f, obfsBuf) b.ResetTimer() diff --git a/internal/multiplex/stream_test.go b/internal/multiplex/stream_test.go index d7737d2..fea1360 100644 --- a/internal/multiplex/stream_test.go +++ b/internal/multiplex/stream_test.go @@ -15,7 +15,13 @@ func setupSesh() *Session { sessionKey := make([]byte, 32) rand.Read(sessionKey) obfuscator, _ := GenerateObfs(0x00, sessionKey) - return MakeSession(0, UNLIMITED_VALVE, obfuscator, util.ReadTLS) + + seshConfig := &SessionConfig{ + Obfuscator: obfuscator, + Valve: nil, + UnitRead: util.ReadTLS, + } + return MakeSession(0, seshConfig) } type blackhole struct { @@ -63,6 +69,59 @@ func BenchmarkStream_Write(b *testing.B) { } } +func BenchmarkStream_Read(b *testing.B) { + sesh := setupSesh() + const PAYLOAD_LEN = 1000 + testPayload := make([]byte, PAYLOAD_LEN) + rand.Read(testPayload) + + f := &Frame{ + 1, + 0, + 0, + testPayload, + } + + obfsBuf := make([]byte, 17000) + + l, _ := net.Listen("tcp", "127.0.0.1:0") + go func() { + // potentially bottlenecked here rather than the actual stream read throughput + conn, _ := net.Dial("tcp", l.Addr().String()) + for { + i, _ := sesh.Obfs(f, obfsBuf) + f.Seq += 1 + _, err := conn.Write(obfsBuf[:i]) + if err != nil { + b.Error("cannot write to connection", err) + } + } + }() + conn, _ := l.Accept() + + sesh.AddConnection(conn) + stream, err := sesh.Accept() + if err != nil { + b.Error("failed to accept stream", err) + } + + //time.Sleep(5*time.Second) // wait for buffer to fill up + + readBuf := make([]byte, PAYLOAD_LEN) + b.ResetTimer() + for j := 0; j < b.N; j++ { + n, err := stream.Read(readBuf) + if !bytes.Equal(readBuf, testPayload) { + b.Error("paylod not equal") + } + b.SetBytes(int64(n)) + if err != nil { + b.Error(err) + } + } + +} + func TestStream_Read(t *testing.T) { sesh := setupSesh() testPayload := []byte{42, 42, 42} diff --git a/internal/multiplex/switchboard.go b/internal/multiplex/switchboard.go index a67e3d2..095db96 100644 --- a/internal/multiplex/switchboard.go +++ b/internal/multiplex/switchboard.go @@ -112,7 +112,7 @@ func (sb *switchboard) closeAll() { func (sb *switchboard) deplex(connId uint32, conn net.Conn) { buf := make([]byte, 20480) for { - n, err := sb.session.unitRead(conn, buf) + n, err := sb.session.UnitRead(conn, buf) sb.rxWait(n) sb.Valve.AddRx(int64(n)) if err != nil { diff --git a/internal/multiplex/switchboard_test.go b/internal/multiplex/switchboard_test.go index 468c191..c5e36f7 100644 --- a/internal/multiplex/switchboard_test.go +++ b/internal/multiplex/switchboard_test.go @@ -1,13 +1,17 @@ package multiplex import ( - "github.com/cbeuw/Cloak/internal/util" "math/rand" "testing" ) func BenchmarkSwitchboard_Send(b *testing.B) { - sesh := MakeSession(0, UNLIMITED_VALVE, nil, util.ReadTLS) + seshConfig := &SessionConfig{ + Obfuscator: nil, + Valve: nil, + UnitRead: nil, + } + sesh := MakeSession(0, seshConfig) sb := makeSwitchboard(sesh, UNLIMITED_VALVE) hole := newBlackHole() sb.addConn(hole) diff --git a/internal/server/activeuser.go b/internal/server/activeuser.go index d741899..4d0b601 100644 --- a/internal/server/activeuser.go +++ b/internal/server/activeuser.go @@ -1,7 +1,6 @@ package server import ( - "net" "sync" mux "github.com/cbeuw/Cloak/internal/multiplex" @@ -34,7 +33,7 @@ func (u *ActiveUser) DeleteSession(sessionID uint32, reason string) { u.sessionsM.Unlock() } -func (u *ActiveUser) GetSession(sessionID uint32, obfuscator *mux.Obfuscator, unitReader func(net.Conn, []byte) (int, error)) (sesh *mux.Session, existing bool, err error) { +func (u *ActiveUser) GetSession(sessionID uint32, config *mux.SessionConfig) (sesh *mux.Session, existing bool, err error) { u.sessionsM.Lock() defer u.sessionsM.Unlock() if sesh = u.sessions[sessionID]; sesh != nil { @@ -46,7 +45,8 @@ func (u *ActiveUser) GetSession(sessionID uint32, obfuscator *mux.Obfuscator, un return nil, false, err } } - sesh = mux.MakeSession(sessionID, u.valve, obfuscator, unitReader) + config.Valve = u.valve + sesh = mux.MakeSession(sessionID, config) u.sessions[sessionID] = sesh return sesh, false, nil } diff --git a/internal/server/activeuser_test.go b/internal/server/activeuser_test.go index 1373a01..b46c3cc 100644 --- a/internal/server/activeuser_test.go +++ b/internal/server/activeuser_test.go @@ -16,16 +16,11 @@ func TestActiveUser_Bypass(t *testing.T) { panel := MakeUserPanel(manager) UID, _ := base64.StdEncoding.DecodeString("u97xvcc5YoQA8obCyt9q/w==") user, _ := panel.GetBypassUser(UID) - obfuscator := &mux.Obfuscator{ - nil, - nil, - nil, - } var sesh0 *mux.Session var existing bool var sesh1 *mux.Session t.Run("get first session", func(t *testing.T) { - sesh0, existing, err = user.GetSession(0, obfuscator, nil) + sesh0, existing, err = user.GetSession(0, &mux.SessionConfig{}) if err != nil { t.Error(err) } @@ -37,7 +32,7 @@ func TestActiveUser_Bypass(t *testing.T) { } }) t.Run("get first session again", func(t *testing.T) { - seshx, existing, err := user.GetSession(0, obfuscator, nil) + seshx, existing, err := user.GetSession(0, &mux.SessionConfig{}) if err != nil { t.Error(err) } @@ -52,7 +47,7 @@ func TestActiveUser_Bypass(t *testing.T) { } }) t.Run("get second session", func(t *testing.T) { - sesh1, existing, err = user.GetSession(1, obfuscator, nil) + sesh1, existing, err = user.GetSession(1, &mux.SessionConfig{}) if err != nil { t.Error(err) } @@ -87,7 +82,7 @@ func TestActiveUser_Bypass(t *testing.T) { } }) t.Run("get session again after termination", func(t *testing.T) { - seshx, existing, err := user.GetSession(0, obfuscator, nil) + seshx, existing, err := user.GetSession(0, &mux.SessionConfig{}) if err != nil { t.Error(err) }