Merge branch 'master' into notsure2

This commit is contained in:
notsure2 2020-12-12 22:06:56 +02:00
commit 1e6a348d4e
3 changed files with 103 additions and 145 deletions

View File

@ -63,22 +63,22 @@ func makeSessionPair(numConn int) (*Session, *Session, []*connPair) {
return clientSession, serverSession, paris return clientSession, serverSession, paris
} }
func runEchoTest(t *testing.T, streams []*Stream) { func runEchoTest(t *testing.T, conns []net.Conn, maxMsgLen int) {
const testDataLen = 16384
var wg sync.WaitGroup var wg sync.WaitGroup
for _, stream := range streams { for _, conn := range conns {
wg.Add(1) wg.Add(1)
go func(stream *Stream) { go func(conn net.Conn) {
testDataLen := rand.Intn(maxMsgLen)
testData := make([]byte, testDataLen) testData := make([]byte, testDataLen)
rand.Read(testData) rand.Read(testData)
n, err := stream.Write(testData) n, err := conn.Write(testData)
if n != testDataLen { if n != testDataLen {
t.Fatalf("written only %v, err %v", n, err) t.Fatalf("written only %v, err %v", n, err)
} }
recvBuf := make([]byte, testDataLen) recvBuf := make([]byte, testDataLen)
_, err = io.ReadFull(stream, recvBuf) _, err = io.ReadFull(conn, recvBuf)
if err != nil { if err != nil {
t.Fatalf("failed to read back: %v", err) t.Fatalf("failed to read back: %v", err)
} }
@ -87,7 +87,7 @@ func runEchoTest(t *testing.T, streams []*Stream) {
t.Fatalf("echoed data not correct") t.Fatalf("echoed data not correct")
} }
wg.Done() wg.Done()
}(stream) }(conn)
} }
wg.Wait() wg.Wait()
} }
@ -95,11 +95,12 @@ func runEchoTest(t *testing.T, streams []*Stream) {
func TestMultiplex(t *testing.T) { func TestMultiplex(t *testing.T) {
const numStreams = 2000 // -race option limits the number of goroutines to 8192 const numStreams = 2000 // -race option limits the number of goroutines to 8192
const numConns = 4 const numConns = 4
const maxMsgLen = 16384
clientSession, serverSession, _ := makeSessionPair(numConns) clientSession, serverSession, _ := makeSessionPair(numConns)
go serveEcho(serverSession) go serveEcho(serverSession)
streams := make([]*Stream, numStreams) streams := make([]net.Conn, numStreams)
for i := 0; i < numStreams; i++ { for i := 0; i < numStreams; i++ {
stream, err := clientSession.OpenStream() stream, err := clientSession.OpenStream()
if err != nil { if err != nil {
@ -109,7 +110,7 @@ func TestMultiplex(t *testing.T) {
} }
//test echo //test echo
runEchoTest(t, streams) runEchoTest(t, streams, maxMsgLen)
if clientSession.streamCount() != numStreams { if clientSession.streamCount() != numStreams {
t.Errorf("client stream count is wrong: %v", clientSession.streamCount()) t.Errorf("client stream count is wrong: %v", clientSession.streamCount())
} }

View File

@ -253,8 +253,7 @@ func (sesh *Session) TerminalMsg() string {
} }
} }
func (sesh *Session) passiveClose() error { func (sesh *Session) closeSession(closeSwitchboard bool) error {
log.Debugf("attempting to passively close session %v", sesh.id)
if atomic.SwapUint32(&sesh.closed, 1) == 1 { if atomic.SwapUint32(&sesh.closed, 1) == 1 {
log.Debugf("session %v has already been closed", sesh.id) log.Debugf("session %v has already been closed", sesh.id)
return errRepeatSessionClosing return errRepeatSessionClosing
@ -273,7 +272,18 @@ func (sesh *Session) passiveClose() error {
return true return true
}) })
sesh.sb.closeAll() if closeSwitchboard {
sesh.sb.closeAll()
}
return nil
}
func (sesh *Session) passiveClose() error {
log.Debugf("attempting to passively close session %v", sesh.id)
err := sesh.closeSession(true)
if err != nil {
return err
}
log.Debugf("session %v closed gracefully", sesh.id) log.Debugf("session %v closed gracefully", sesh.id)
return nil return nil
} }
@ -288,25 +298,10 @@ func genRandomPadding() []byte {
func (sesh *Session) Close() error { func (sesh *Session) Close() error {
log.Debugf("attempting to actively close session %v", sesh.id) log.Debugf("attempting to actively close session %v", sesh.id)
if atomic.SwapUint32(&sesh.closed, 1) == 1 { err := sesh.closeSession(false)
log.Debugf("session %v has already been closed", sesh.id) if err != nil {
return errRepeatSessionClosing return err
} }
sesh.acceptCh <- nil
// close all streams
sesh.streams.Range(func(key, streamI interface{}) bool {
if streamI == nil {
return true
}
stream := streamI.(*Stream)
atomic.StoreUint32(&stream.closed, 1)
_ = stream.recvBuf.Close() // will not block
sesh.streams.Delete(key)
sesh.streamCountDecr()
return true
})
// we send a notice frame telling remote to close the session // we send a notice frame telling remote to close the session
pad := genRandomPadding() pad := genRandomPadding()
f := &Frame{ f := &Frame{

View File

@ -17,8 +17,10 @@ var seshConfigUnordered = SessionConfig{
Unordered: true, Unordered: true,
} }
const testPayloadLen = 1024
const obfsBufLen = testPayloadLen * 2
func TestRecvDataFromRemote(t *testing.T) { func TestRecvDataFromRemote(t *testing.T) {
testPayloadLen := 1024
testPayload := make([]byte, testPayloadLen) testPayload := make([]byte, testPayloadLen)
rand.Read(testPayload) rand.Read(testPayload)
f := &Frame{ f := &Frame{
@ -27,126 +29,88 @@ func TestRecvDataFromRemote(t *testing.T) {
0, 0,
testPayload, testPayload,
} }
obfsBuf := make([]byte, 17000) obfsBuf := make([]byte, obfsBufLen)
var sessionKey [32]byte var sessionKey [32]byte
rand.Read(sessionKey[:]) rand.Read(sessionKey[:])
t.Run("plain ordered", func(t *testing.T) {
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey)
seshConfigOrdered.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfigOrdered)
n, _ := sesh.Obfs(f, obfsBuf, 0)
err := sesh.recvDataFromRemote(obfsBuf[:n]) MakeObfuscatorUnwrap := func(method byte, sessionKey [32]byte) Obfuscator {
ret, err := MakeObfuscator(method, sessionKey)
if err != nil { if err != nil {
t.Error(err) t.Fatalf("failed to make an obfuscator: %v", err)
return
}
stream, err := sesh.Accept()
if err != nil {
t.Error(err)
return
} }
return ret
}
resultPayload := make([]byte, testPayloadLen) sessionTypes := []struct {
_, err = stream.Read(resultPayload) name string
if err != nil { config SessionConfig
t.Error(err) }{
return {"ordered",
} SessionConfig{}},
if !bytes.Equal(testPayload, resultPayload) { {"unordered",
t.Errorf("Expecting %x, got %x", testPayload, resultPayload) SessionConfig{Unordered: true}},
} }
})
t.Run("aes-gcm ordered", func(t *testing.T) {
obfuscator, _ := MakeObfuscator(EncryptionMethodAESGCM, sessionKey)
seshConfigOrdered.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfigOrdered)
n, _ := sesh.Obfs(f, obfsBuf, 0)
err := sesh.recvDataFromRemote(obfsBuf[:n]) encryptionMethods := []struct {
if err != nil { name string
t.Error(err) obfuscator Obfuscator
return }{
} {
stream, err := sesh.Accept() "plain",
if err != nil { MakeObfuscatorUnwrap(EncryptionMethodPlain, sessionKey),
t.Error(err) },
return {
} "aes-gcm",
MakeObfuscatorUnwrap(EncryptionMethodAESGCM, sessionKey),
},
{
"chacha20-poly1305",
MakeObfuscatorUnwrap(EncryptionMethodChaha20Poly1305, sessionKey),
},
}
resultPayload := make([]byte, testPayloadLen) for _, st := range sessionTypes {
_, err = stream.Read(resultPayload) t.Run(st.name, func(t *testing.T) {
if err != nil { for _, em := range encryptionMethods {
t.Error(err) t.Run(em.name, func(t *testing.T) {
return st.config.Obfuscator = em.obfuscator
} sesh := MakeSession(0, st.config)
if !bytes.Equal(testPayload, resultPayload) { n, err := sesh.Obfs(f, obfsBuf, 0)
t.Errorf("Expecting %x, got %x", testPayload, resultPayload) if err != nil {
} t.Error(err)
}) return
t.Run("chacha20-poly1305 ordered", func(t *testing.T) { }
obfuscator, _ := MakeObfuscator(EncryptionMethodChaha20Poly1305, sessionKey) err = sesh.recvDataFromRemote(obfsBuf[:n])
seshConfigOrdered.Obfuscator = obfuscator if err != nil {
sesh := MakeSession(0, seshConfigOrdered) t.Error(err)
n, _ := sesh.Obfs(f, obfsBuf, 0) return
}
stream, err := sesh.Accept()
if err != nil {
t.Error(err)
return
}
err := sesh.recvDataFromRemote(obfsBuf[:n]) resultPayload := make([]byte, testPayloadLen)
if err != nil { _, err = stream.Read(resultPayload)
t.Error(err) if err != nil {
return t.Error(err)
} return
stream, err := sesh.Accept() }
if err != nil { if !bytes.Equal(testPayload, resultPayload) {
t.Error(err) t.Errorf("Expecting %x, got %x", testPayload, resultPayload)
return }
} })
}
resultPayload := make([]byte, testPayloadLen) })
_, err = stream.Read(resultPayload) }
if err != nil {
t.Error(err)
return
}
if !bytes.Equal(testPayload, resultPayload) {
t.Errorf("Expecting %x, got %x", testPayload, resultPayload)
}
})
t.Run("plain unordered", func(t *testing.T) {
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey)
seshConfigUnordered.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfigOrdered)
n, _ := sesh.Obfs(f, obfsBuf, 0)
err := sesh.recvDataFromRemote(obfsBuf[:n])
if err != nil {
t.Error(err)
return
}
stream, err := sesh.Accept()
if err != nil {
t.Error(err)
return
}
resultPayload := make([]byte, testPayloadLen)
_, err = stream.Read(resultPayload)
if err != nil {
t.Error(err)
return
}
if !bytes.Equal(testPayload, resultPayload) {
t.Errorf("Expecting %x, got %x", testPayload, resultPayload)
}
})
} }
func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
testPayloadLen := 1024
testPayload := make([]byte, testPayloadLen) testPayload := make([]byte, testPayloadLen)
rand.Read(testPayload) rand.Read(testPayload)
obfsBuf := make([]byte, 17000) obfsBuf := make([]byte, obfsBufLen)
var sessionKey [32]byte var sessionKey [32]byte
rand.Read(sessionKey[:]) rand.Read(sessionKey[:])
@ -273,10 +237,9 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) { func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) {
// Tests for when the closing frame of a stream is received first before any data frame // Tests for when the closing frame of a stream is received first before any data frame
testPayloadLen := 1024
testPayload := make([]byte, testPayloadLen) testPayload := make([]byte, testPayloadLen)
rand.Read(testPayload) rand.Read(testPayload)
obfsBuf := make([]byte, 17000) obfsBuf := make([]byte, obfsBufLen)
var sessionKey [32]byte var sessionKey [32]byte
rand.Read(sessionKey[:]) rand.Read(sessionKey[:])
@ -354,7 +317,7 @@ func TestParallelStreams(t *testing.T) {
} }
} }
numOfTests := 5000 const numOfTests = 5000
tests := make([]struct { tests := make([]struct {
name string name string
frame *Frame frame *Frame
@ -368,11 +331,11 @@ func TestParallelStreams(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
wg.Add(1) wg.Add(1)
go func(frame *Frame) { go func(frame *Frame) {
data := make([]byte, 1000) obfsBuf := make([]byte, obfsBufLen)
n, _ := sesh.Obfs(frame, data, 0) n, _ := sesh.Obfs(frame, obfsBuf, 0)
data = data[0:n] obfsBuf = obfsBuf[0:n]
err := sesh.recvDataFromRemote(data) err := sesh.recvDataFromRemote(obfsBuf)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -452,7 +415,6 @@ func TestSession_timeoutAfter(t *testing.T) {
} }
func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) { func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) {
testPayloadLen := 1024
testPayload := make([]byte, testPayloadLen) testPayload := make([]byte, testPayloadLen)
rand.Read(testPayload) rand.Read(testPayload)
f := &Frame{ f := &Frame{
@ -461,7 +423,7 @@ func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) {
0, 0,
testPayload, testPayload,
} }
obfsBuf := make([]byte, 17000) obfsBuf := make([]byte, obfsBufLen)
var sessionKey [32]byte var sessionKey [32]byte
rand.Read(sessionKey[:]) rand.Read(sessionKey[:])