Refactor session configuration

This commit is contained in:
Qian Wang 2019-08-12 00:22:15 +01:00
parent c3d4057315
commit cb672a99de
9 changed files with 126 additions and 34 deletions

View File

@ -73,7 +73,13 @@ func makeSession(sta *client.State) *mux.Session {
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
sesh := mux.MakeSession(sta.SessionID, mux.UNLIMITED_VALVE, obfuscator, util.ReadTLS)
seshConfig := &mux.SessionConfig{
Obfuscator: obfuscator,
Valve: nil,
UnitRead: util.ReadTLS,
}
sesh := mux.MakeSession(sta.SessionID, seshConfig)
for i := 0; i < sta.NumConn; i++ { for i := 0; i < sta.NumConn; i++ {
conn := <-connsCh conn := <-connsCh

View File

@ -84,7 +84,12 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
return return
} }
log.Trace("finished handshake") log.Trace("finished handshake")
sesh := mux.MakeSession(0, mux.UNLIMITED_VALVE, obfuscator, util.ReadTLS) seshConfig := &mux.SessionConfig{
Obfuscator: obfuscator,
Valve: nil,
UnitRead: util.ReadTLS,
}
sesh := mux.MakeSession(0, seshConfig)
sesh.AddConnection(conn) sesh.AddConnection(conn)
//TODO: Router could be nil in cnc mode //TODO: Router could be nil in cnc mode
log.WithField("remoteAddr", conn.RemoteAddr()).Info("New admin session") log.WithField("remoteAddr", conn.RemoteAddr()).Info("New admin session")
@ -111,7 +116,12 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
return return
} }
sesh, existing, err := user.GetSession(sessionID, obfuscator, util.ReadTLS) seshConfig := &mux.SessionConfig{
Obfuscator: obfuscator,
Valve: nil,
UnitRead: util.ReadTLS,
}
sesh, existing, err := user.GetSession(sessionID, seshConfig)
if err != nil { if err != nil {
user.DeleteSession(sessionID, "") user.DeleteSession(sessionID, "")
log.Error(err) log.Error(err)

View File

@ -25,13 +25,19 @@ type Obfuscator struct {
SessionKey []byte SessionKey []byte
} }
type SessionConfig struct {
*Obfuscator
Valve
// This is supposed to read one TLS message, the same as GoQuiet's ReadTillDrain
UnitRead func(net.Conn, []byte) (int, error)
}
type Session struct { type Session struct {
id uint32 id uint32
*Obfuscator *SessionConfig
// This is supposed to read one TLS message, the same as GoQuiet's ReadTillDrain
unitRead func(net.Conn, []byte) (int, error)
// atomic // atomic
nextStreamID uint32 nextStreamID uint32
@ -52,17 +58,20 @@ type Session struct {
terminalMsg atomic.Value terminalMsg atomic.Value
} }
func MakeSession(id uint32, valve Valve, obfuscator *Obfuscator, unitReader func(net.Conn, []byte) (int, error)) *Session { func MakeSession(id uint32, config *SessionConfig) *Session {
sesh := &Session{ sesh := &Session{
id: id, id: id,
unitRead: unitReader, SessionConfig: config,
nextStreamID: 1, nextStreamID: 1,
Obfuscator: obfuscator,
streams: make(map[uint32]*Stream), streams: make(map[uint32]*Stream),
acceptCh: make(chan *Stream, acceptBacklog), acceptCh: make(chan *Stream, acceptBacklog),
} }
sesh.addrs.Store([]net.Addr{nil, nil}) sesh.addrs.Store([]net.Addr{nil, nil})
sesh.sb = makeSwitchboard(sesh, valve)
if config.Valve == nil {
config.Valve = UNLIMITED_VALVE
}
sesh.sb = makeSwitchboard(sesh, config.Valve)
go sesh.timeoutAfter(30 * time.Second) go sesh.timeoutAfter(30 * time.Second)
return sesh return sesh
} }

View File

@ -6,6 +6,12 @@ import (
"testing" "testing"
) )
var seshConfig = &SessionConfig{
Obfuscator: nil,
Valve: nil,
UnitRead: util.ReadTLS,
}
func BenchmarkRecvDataFromRemote(b *testing.B) { func BenchmarkRecvDataFromRemote(b *testing.B) {
testPayload := make([]byte, 1024) testPayload := make([]byte, 1024)
rand.Read(testPayload) rand.Read(testPayload)
@ -22,7 +28,8 @@ func BenchmarkRecvDataFromRemote(b *testing.B) {
b.Run("plain", func(b *testing.B) { b.Run("plain", func(b *testing.B) {
obfuscator, _ := GenerateObfs(0x00, sessionKey) obfuscator, _ := GenerateObfs(0x00, sessionKey)
sesh := MakeSession(0, UNLIMITED_VALVE, obfuscator, util.ReadTLS) seshConfig.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfig)
n, _ := sesh.Obfs(f, obfsBuf) n, _ := sesh.Obfs(f, obfsBuf)
b.ResetTimer() b.ResetTimer()
@ -34,7 +41,8 @@ func BenchmarkRecvDataFromRemote(b *testing.B) {
b.Run("aes-gcm", func(b *testing.B) { b.Run("aes-gcm", func(b *testing.B) {
obfuscator, _ := GenerateObfs(0x01, sessionKey) obfuscator, _ := GenerateObfs(0x01, sessionKey)
sesh := MakeSession(0, UNLIMITED_VALVE, obfuscator, util.ReadTLS) seshConfig.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfig)
n, _ := sesh.Obfs(f, obfsBuf) n, _ := sesh.Obfs(f, obfsBuf)
b.ResetTimer() b.ResetTimer()
@ -46,7 +54,8 @@ func BenchmarkRecvDataFromRemote(b *testing.B) {
b.Run("chacha20-poly1305", func(b *testing.B) { b.Run("chacha20-poly1305", func(b *testing.B) {
obfuscator, _ := GenerateObfs(0x02, sessionKey) obfuscator, _ := GenerateObfs(0x02, sessionKey)
sesh := MakeSession(0, UNLIMITED_VALVE, obfuscator, util.ReadTLS) seshConfig.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfig)
n, _ := sesh.Obfs(f, obfsBuf) n, _ := sesh.Obfs(f, obfsBuf)
b.ResetTimer() b.ResetTimer()

View File

@ -15,7 +15,13 @@ func setupSesh() *Session {
sessionKey := make([]byte, 32) sessionKey := make([]byte, 32)
rand.Read(sessionKey) rand.Read(sessionKey)
obfuscator, _ := GenerateObfs(0x00, sessionKey) obfuscator, _ := GenerateObfs(0x00, sessionKey)
return MakeSession(0, UNLIMITED_VALVE, obfuscator, util.ReadTLS)
seshConfig := &SessionConfig{
Obfuscator: obfuscator,
Valve: nil,
UnitRead: util.ReadTLS,
}
return MakeSession(0, seshConfig)
} }
type blackhole struct { type blackhole struct {
@ -63,6 +69,59 @@ func BenchmarkStream_Write(b *testing.B) {
} }
} }
func BenchmarkStream_Read(b *testing.B) {
sesh := setupSesh()
const PAYLOAD_LEN = 1000
testPayload := make([]byte, PAYLOAD_LEN)
rand.Read(testPayload)
f := &Frame{
1,
0,
0,
testPayload,
}
obfsBuf := make([]byte, 17000)
l, _ := net.Listen("tcp", "127.0.0.1:0")
go func() {
// potentially bottlenecked here rather than the actual stream read throughput
conn, _ := net.Dial("tcp", l.Addr().String())
for {
i, _ := sesh.Obfs(f, obfsBuf)
f.Seq += 1
_, err := conn.Write(obfsBuf[:i])
if err != nil {
b.Error("cannot write to connection", err)
}
}
}()
conn, _ := l.Accept()
sesh.AddConnection(conn)
stream, err := sesh.Accept()
if err != nil {
b.Error("failed to accept stream", err)
}
//time.Sleep(5*time.Second) // wait for buffer to fill up
readBuf := make([]byte, PAYLOAD_LEN)
b.ResetTimer()
for j := 0; j < b.N; j++ {
n, err := stream.Read(readBuf)
if !bytes.Equal(readBuf, testPayload) {
b.Error("paylod not equal")
}
b.SetBytes(int64(n))
if err != nil {
b.Error(err)
}
}
}
func TestStream_Read(t *testing.T) { func TestStream_Read(t *testing.T) {
sesh := setupSesh() sesh := setupSesh()
testPayload := []byte{42, 42, 42} testPayload := []byte{42, 42, 42}

View File

@ -112,7 +112,7 @@ func (sb *switchboard) closeAll() {
func (sb *switchboard) deplex(connId uint32, conn net.Conn) { func (sb *switchboard) deplex(connId uint32, conn net.Conn) {
buf := make([]byte, 20480) buf := make([]byte, 20480)
for { for {
n, err := sb.session.unitRead(conn, buf) n, err := sb.session.UnitRead(conn, buf)
sb.rxWait(n) sb.rxWait(n)
sb.Valve.AddRx(int64(n)) sb.Valve.AddRx(int64(n))
if err != nil { if err != nil {

View File

@ -1,13 +1,17 @@
package multiplex package multiplex
import ( import (
"github.com/cbeuw/Cloak/internal/util"
"math/rand" "math/rand"
"testing" "testing"
) )
func BenchmarkSwitchboard_Send(b *testing.B) { func BenchmarkSwitchboard_Send(b *testing.B) {
sesh := MakeSession(0, UNLIMITED_VALVE, nil, util.ReadTLS) seshConfig := &SessionConfig{
Obfuscator: nil,
Valve: nil,
UnitRead: nil,
}
sesh := MakeSession(0, seshConfig)
sb := makeSwitchboard(sesh, UNLIMITED_VALVE) sb := makeSwitchboard(sesh, UNLIMITED_VALVE)
hole := newBlackHole() hole := newBlackHole()
sb.addConn(hole) sb.addConn(hole)

View File

@ -1,7 +1,6 @@
package server package server
import ( import (
"net"
"sync" "sync"
mux "github.com/cbeuw/Cloak/internal/multiplex" mux "github.com/cbeuw/Cloak/internal/multiplex"
@ -34,7 +33,7 @@ func (u *ActiveUser) DeleteSession(sessionID uint32, reason string) {
u.sessionsM.Unlock() u.sessionsM.Unlock()
} }
func (u *ActiveUser) GetSession(sessionID uint32, obfuscator *mux.Obfuscator, unitReader func(net.Conn, []byte) (int, error)) (sesh *mux.Session, existing bool, err error) { func (u *ActiveUser) GetSession(sessionID uint32, config *mux.SessionConfig) (sesh *mux.Session, existing bool, err error) {
u.sessionsM.Lock() u.sessionsM.Lock()
defer u.sessionsM.Unlock() defer u.sessionsM.Unlock()
if sesh = u.sessions[sessionID]; sesh != nil { if sesh = u.sessions[sessionID]; sesh != nil {
@ -46,7 +45,8 @@ func (u *ActiveUser) GetSession(sessionID uint32, obfuscator *mux.Obfuscator, un
return nil, false, err return nil, false, err
} }
} }
sesh = mux.MakeSession(sessionID, u.valve, obfuscator, unitReader) config.Valve = u.valve
sesh = mux.MakeSession(sessionID, config)
u.sessions[sessionID] = sesh u.sessions[sessionID] = sesh
return sesh, false, nil return sesh, false, nil
} }

View File

@ -16,16 +16,11 @@ func TestActiveUser_Bypass(t *testing.T) {
panel := MakeUserPanel(manager) panel := MakeUserPanel(manager)
UID, _ := base64.StdEncoding.DecodeString("u97xvcc5YoQA8obCyt9q/w==") UID, _ := base64.StdEncoding.DecodeString("u97xvcc5YoQA8obCyt9q/w==")
user, _ := panel.GetBypassUser(UID) user, _ := panel.GetBypassUser(UID)
obfuscator := &mux.Obfuscator{
nil,
nil,
nil,
}
var sesh0 *mux.Session var sesh0 *mux.Session
var existing bool var existing bool
var sesh1 *mux.Session var sesh1 *mux.Session
t.Run("get first session", func(t *testing.T) { t.Run("get first session", func(t *testing.T) {
sesh0, existing, err = user.GetSession(0, obfuscator, nil) sesh0, existing, err = user.GetSession(0, &mux.SessionConfig{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -37,7 +32,7 @@ func TestActiveUser_Bypass(t *testing.T) {
} }
}) })
t.Run("get first session again", func(t *testing.T) { t.Run("get first session again", func(t *testing.T) {
seshx, existing, err := user.GetSession(0, obfuscator, nil) seshx, existing, err := user.GetSession(0, &mux.SessionConfig{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -52,7 +47,7 @@ func TestActiveUser_Bypass(t *testing.T) {
} }
}) })
t.Run("get second session", func(t *testing.T) { t.Run("get second session", func(t *testing.T) {
sesh1, existing, err = user.GetSession(1, obfuscator, nil) sesh1, existing, err = user.GetSession(1, &mux.SessionConfig{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -87,7 +82,7 @@ func TestActiveUser_Bypass(t *testing.T) {
} }
}) })
t.Run("get session again after termination", func(t *testing.T) { t.Run("get session again after termination", func(t *testing.T) {
seshx, existing, err := user.GetSession(0, obfuscator, nil) seshx, existing, err := user.GetSession(0, &mux.SessionConfig{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }