diff --git a/internal/multiplex/session_test.go b/internal/multiplex/session_test.go index 049a00c..210bce4 100644 --- a/internal/multiplex/session_test.go +++ b/internal/multiplex/session_test.go @@ -30,16 +30,6 @@ const testPayloadLen = 1024 const obfsBufLen = testPayloadLen * 2 func TestRecvDataFromRemote(t *testing.T) { - testPayload := make([]byte, testPayloadLen) - rand.Read(testPayload) - f := &Frame{ - 1, - 0, - 0, - testPayload, - } - obfsBuf := make([]byte, obfsBufLen) - var sessionKey [32]byte rand.Read(sessionKey[:]) @@ -54,31 +44,164 @@ func TestRecvDataFromRemote(t *testing.T) { if err != nil { t.Fatalf("failed to make obfuscator: %v", err) } - sesh := MakeSession(0, seshConfig) - n, err := sesh.obfuscate(f, obfsBuf, 0) - if err != nil { - t.Error(err) - return - } - err = sesh.recvDataFromRemote(obfsBuf[:n]) - if err != nil { - t.Error(err) - return - } - stream, err := sesh.Accept() - if err != nil { - t.Error(err) - return - } + t.Run("initial frame", func(t *testing.T) { + sesh := MakeSession(0, seshConfig) + obfsBuf := make([]byte, obfsBufLen) + f := Frame{ + 1, + 0, + 0, + make([]byte, testPayloadLen), + } + rand.Read(f.Payload) + n, err := sesh.obfuscate(&f, obfsBuf, 0) + assert.NoError(t, err) + err = sesh.recvDataFromRemote(obfsBuf[:n]) + assert.NoError(t, err) + stream, err := sesh.Accept() + assert.NoError(t, err) - resultPayload := make([]byte, testPayloadLen) - _, err = stream.Read(resultPayload) - if err != nil { - t.Error(err) - return - } - if !bytes.Equal(testPayload, resultPayload) { - t.Errorf("Expecting %x, got %x", testPayload, resultPayload) + resultPayload := make([]byte, testPayloadLen) + _, err = stream.Read(resultPayload) + assert.NoError(t, err) + + assert.EqualValues(t, f.Payload, resultPayload) + }) + + t.Run("two frames in order", func(t *testing.T) { + sesh := MakeSession(0, seshConfig) + obfsBuf := make([]byte, obfsBufLen) + f := Frame{ + 1, + 0, + 0, + make([]byte, testPayloadLen), + } + rand.Read(f.Payload) + n, err := sesh.obfuscate(&f, obfsBuf, 0) + assert.NoError(t, err) + err = sesh.recvDataFromRemote(obfsBuf[:n]) + assert.NoError(t, err) + stream, err := sesh.Accept() + assert.NoError(t, err) + + resultPayload := make([]byte, testPayloadLen) + _, err = io.ReadFull(stream, resultPayload) + assert.NoError(t, err) + + assert.EqualValues(t, f.Payload, resultPayload) + + f.Seq += 1 + rand.Read(f.Payload) + n, err = sesh.obfuscate(&f, obfsBuf, 0) + assert.NoError(t, err) + err = sesh.recvDataFromRemote(obfsBuf[:n]) + assert.NoError(t, err) + + _, err = io.ReadFull(stream, resultPayload) + assert.NoError(t, err) + + assert.EqualValues(t, f.Payload, resultPayload) + }) + + t.Run("two frames in order", func(t *testing.T) { + sesh := MakeSession(0, seshConfig) + obfsBuf := make([]byte, obfsBufLen) + f := Frame{ + 1, + 0, + 0, + make([]byte, testPayloadLen), + } + rand.Read(f.Payload) + n, err := sesh.obfuscate(&f, obfsBuf, 0) + assert.NoError(t, err) + err = sesh.recvDataFromRemote(obfsBuf[:n]) + assert.NoError(t, err) + stream, err := sesh.Accept() + assert.NoError(t, err) + + resultPayload := make([]byte, testPayloadLen) + _, err = io.ReadFull(stream, resultPayload) + assert.NoError(t, err) + + assert.EqualValues(t, f.Payload, resultPayload) + + f.Seq += 1 + rand.Read(f.Payload) + n, err = sesh.obfuscate(&f, obfsBuf, 0) + assert.NoError(t, err) + err = sesh.recvDataFromRemote(obfsBuf[:n]) + assert.NoError(t, err) + + _, err = io.ReadFull(stream, resultPayload) + assert.NoError(t, err) + + assert.EqualValues(t, f.Payload, resultPayload) + }) + + if seshType == "ordered" { + t.Run("frames out of order", func(t *testing.T) { + sesh := MakeSession(0, seshConfig) + obfsBuf := make([]byte, obfsBufLen) + f := Frame{ + 1, + 0, + 0, + nil, + } + + // First frame + seq0 := make([]byte, testPayloadLen) + rand.Read(seq0) + f.Seq = 0 + f.Payload = seq0 + n, err := sesh.obfuscate(&f, obfsBuf, 0) + assert.NoError(t, err) + err = sesh.recvDataFromRemote(obfsBuf[:n]) + assert.NoError(t, err) + + // Third frame + seq2 := make([]byte, testPayloadLen) + rand.Read(seq2) + f.Seq = 2 + f.Payload = seq2 + n, err = sesh.obfuscate(&f, obfsBuf, 0) + assert.NoError(t, err) + err = sesh.recvDataFromRemote(obfsBuf[:n]) + assert.NoError(t, err) + + // Second frame + seq1 := make([]byte, testPayloadLen) + rand.Read(seq1) + f.Seq = 1 + f.Payload = seq1 + n, err = sesh.obfuscate(&f, obfsBuf, 0) + assert.NoError(t, err) + err = sesh.recvDataFromRemote(obfsBuf[:n]) + assert.NoError(t, err) + + // Expect things to receive in order + stream, err := sesh.Accept() + assert.NoError(t, err) + + resultPayload := make([]byte, testPayloadLen) + + // First + _, err = io.ReadFull(stream, resultPayload) + assert.NoError(t, err) + assert.EqualValues(t, seq0, resultPayload) + + // Second + _, err = io.ReadFull(stream, resultPayload) + assert.NoError(t, err) + assert.EqualValues(t, seq1, resultPayload) + + // Third + _, err = io.ReadFull(stream, resultPayload) + assert.NoError(t, err) + assert.EqualValues(t, seq2, resultPayload) + }) } }) }