diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 6aaa77e..1b59569 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -13,10 +13,10 @@ jobs: - name: Build run: | export PATH=${PATH}:`go env GOPATH`/bin - v=${{ github.ref }} ./release.sh + v=${GITHUB_REF#refs/*/} ./release.sh - name: Release uses: softprops/action-gh-release@v1 with: - files: ./release/ck-* + files: release/* env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} \ No newline at end of file diff --git a/README.md b/README.md index 9b50dfc..285575f 100644 --- a/README.md +++ b/README.md @@ -103,15 +103,13 @@ Example: `PrivateKey` is the static curve25519 Diffie-Hellman private key encoded in base64. -`AdminUID` is the UID of the admin user in base64. - `BypassUID` is a list of UIDs that are authorised without any bandwidth or credit limit restrictions -`DatabasePath` is the path to `userinfo.db`. If `userinfo.db` doesn't exist in this directory, Cloak will create one -automatically. **If Cloak is started as a Shadowsocks plugin and Shadowsocks is started with its working directory as -/ (e.g. starting ss-server with systemctl), you need to set this field as an absolute path to a desired folder. If you -leave it as default then Cloak will attempt to create userinfo.db under /, which it doesn't have the permission to do so -and will raise an error. See Issue #13.** +`AdminUID` is the UID of the admin user in base64. You can leave this empty if you only ever add users to `BypassUID`. + +`DatabasePath` is the path to `userinfo.db`, which is used to store user usage information and restrictions. Cloak will +create the file automatically if it doesn't exist. You can leave this empty if you only ever add users to `BypassUID`. +This field also has no effect if `AdminUID` isn't a valid UID or is empty. `KeepAlive` is the number of seconds to tell the OS to wait after no activity before sending TCP KeepAlive probes to the upstream proxy server. Zero or negative value disables it. Default is 0 (disabled). @@ -184,6 +182,8 @@ Run `ck-server -uid` and add the UID into the `BypassUID` field in `ckserver.jso ##### Users subject to bandwidth and credit controls +0. First make sure you have `AdminUID` generated and set in `ckserver.json`, along with a path to `userinfo.db` + in `DatabasePath` (Cloak will create this file for you if it didn't already exist). 1. On your client, run `ck-client -s -l -a -c ` to enter admin mode 2. Visit https://cbeuw.github.io/Cloak-panel (Note: this is a pure-js static site, there is no backend and all data diff --git a/internal/client/auth.go b/internal/client/auth.go index 939a34d..4925541 100644 --- a/internal/client/auth.go +++ b/internal/client/auth.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "github.com/cbeuw/Cloak/internal/common" "github.com/cbeuw/Cloak/internal/ecdh" + log "github.com/sirupsen/logrus" ) const ( @@ -26,7 +27,10 @@ func makeAuthenticationPayload(authInfo AuthInfo) (ret authenticationPayload, sh | 16 bytes | 12 bytes | 1 byte | 8 bytes | 4 bytes | 1 byte | 6 bytes | +----------+----------------+---------------------+-------------+--------------+--------+------------+ */ - ephPv, ephPub, _ := ecdh.GenerateKey(authInfo.WorldState.Rand) + ephPv, ephPub, err := ecdh.GenerateKey(authInfo.WorldState.Rand) + if err != nil { + log.Panicf("failed to generate ephemeral key pair: %v", err) + } copy(ret.randPubKey[:], ecdh.Marshal(ephPub)) plaintext := make([]byte, 48) @@ -40,7 +44,11 @@ func makeAuthenticationPayload(authInfo AuthInfo) (ret authenticationPayload, sh plaintext[41] |= UNORDERED_FLAG } - copy(sharedSecret[:], ecdh.GenerateSharedSecret(ephPv, authInfo.ServerPubKey)) + secret, err := ecdh.GenerateSharedSecret(ephPv, authInfo.ServerPubKey) + if err != nil { + log.Panicf("error in generating shared secret: %v", err) + } + copy(sharedSecret[:], secret) ciphertextWithTag, _ := common.AESGCMEncrypt(ret.randPubKey[:12], sharedSecret[:], plaintext) copy(ret.ciphertextWithTag[:], ciphertextWithTag[:]) return diff --git a/internal/ecdh/curve25519.go b/internal/ecdh/curve25519.go index 94d066b..5744c5e 100644 --- a/internal/ecdh/curve25519.go +++ b/internal/ecdh/curve25519.go @@ -68,13 +68,11 @@ func Unmarshal(data []byte) (crypto.PublicKey, bool) { return &pub, true } -func GenerateSharedSecret(privKey crypto.PrivateKey, pubKey crypto.PublicKey) []byte { - var priv, pub, secret *[32]byte +func GenerateSharedSecret(privKey crypto.PrivateKey, pubKey crypto.PublicKey) ([]byte, error) { + var priv, pub *[32]byte priv = privKey.(*[32]byte) pub = pubKey.(*[32]byte) - secret = new([32]byte) - curve25519.ScalarMult(secret, priv, pub) - return secret[:] + return curve25519.X25519(priv[:], pub[:]) } diff --git a/internal/ecdh/curve25519_test.go b/internal/ecdh/curve25519_test.go index 8e9a1c1..39d56ba 100644 --- a/internal/ecdh/curve25519_test.go +++ b/internal/ecdh/curve25519_test.go @@ -90,11 +90,11 @@ func testECDH(t testing.TB) { t.Fatalf("Unmarshal does not work") } - secret1 = GenerateSharedSecret(privKey1, pubKey2) + secret1, err = GenerateSharedSecret(privKey1, pubKey2) if err != nil { t.Error(err) } - secret2 = GenerateSharedSecret(privKey2, pubKey1) + secret2, err = GenerateSharedSecret(privKey2, pubKey1) if err != nil { t.Error(err) } diff --git a/internal/multiplex/datagramBufferedPipe.go b/internal/multiplex/datagramBufferedPipe.go index e1a0462..a7b99e4 100644 --- a/internal/multiplex/datagramBufferedPipe.go +++ b/internal/multiplex/datagramBufferedPipe.go @@ -112,7 +112,7 @@ func (d *datagramBufferedPipe) WriteTo(w io.Writer) (n int64, err error) { } } -func (d *datagramBufferedPipe) Write(f Frame) (toBeClosed bool, err error) { +func (d *datagramBufferedPipe) Write(f *Frame) (toBeClosed bool, err error) { d.rwCond.L.Lock() defer d.rwCond.L.Unlock() if d.buf == nil { diff --git a/internal/multiplex/datagramBufferedPipe_test.go b/internal/multiplex/datagramBufferedPipe_test.go index 4a5d4e2..6b20f76 100644 --- a/internal/multiplex/datagramBufferedPipe_test.go +++ b/internal/multiplex/datagramBufferedPipe_test.go @@ -10,7 +10,7 @@ func TestDatagramBuffer_RW(t *testing.T) { b := []byte{0x01, 0x02, 0x03} t.Run("simple write", func(t *testing.T) { pipe := NewDatagramBufferedPipe() - _, err := pipe.Write(Frame{Payload: b}) + _, err := pipe.Write(&Frame{Payload: b}) if err != nil { t.Error( "expecting", "nil error", @@ -22,7 +22,7 @@ func TestDatagramBuffer_RW(t *testing.T) { t.Run("simple read", func(t *testing.T) { pipe := NewDatagramBufferedPipe() - _, _ = pipe.Write(Frame{Payload: b}) + _, _ = pipe.Write(&Frame{Payload: b}) b2 := make([]byte, len(b)) n, err := pipe.Read(b2) if n != len(b) { @@ -55,7 +55,7 @@ func TestDatagramBuffer_RW(t *testing.T) { t.Run("writing closing frame", func(t *testing.T) { pipe := NewDatagramBufferedPipe() - toBeClosed, err := pipe.Write(Frame{Closing: closingStream}) + toBeClosed, err := pipe.Write(&Frame{Closing: closingStream}) if !toBeClosed { t.Error("should be to be closed") } @@ -77,7 +77,7 @@ func TestDatagramBuffer_BlockingRead(t *testing.T) { b := []byte{0x01, 0x02, 0x03} go func() { time.Sleep(readBlockTime) - pipe.Write(Frame{Payload: b}) + pipe.Write(&Frame{Payload: b}) }() b2 := make([]byte, len(b)) n, err := pipe.Read(b2) @@ -110,7 +110,7 @@ func TestDatagramBuffer_BlockingRead(t *testing.T) { func TestDatagramBuffer_CloseThenRead(t *testing.T) { pipe := NewDatagramBufferedPipe() b := []byte{0x01, 0x02, 0x03} - pipe.Write(Frame{Payload: b}) + pipe.Write(&Frame{Payload: b}) b2 := make([]byte, len(b)) pipe.Close() n, err := pipe.Read(b2) diff --git a/internal/multiplex/mux_test.go b/internal/multiplex/mux_test.go index 76344ca..c8c60f4 100644 --- a/internal/multiplex/mux_test.go +++ b/internal/multiplex/mux_test.go @@ -10,7 +10,6 @@ import ( "net" "sync" "testing" - "time" ) func serveEcho(l net.Listener) { @@ -64,21 +63,20 @@ func makeSessionPair(numConn int) (*Session, *Session, []*connPair) { return clientSession, serverSession, paris } -func runEchoTest(t *testing.T, conns []net.Conn, maxMsgLen int) { +func runEchoTest(t *testing.T, conns []net.Conn, msgLen int) { var wg sync.WaitGroup for _, conn := range conns { wg.Add(1) go func(conn net.Conn) { - testDataLen := rand.Intn(maxMsgLen) - testData := make([]byte, testDataLen) + testData := make([]byte, msgLen) rand.Read(testData) n, err := conn.Write(testData) - if n != testDataLen { + if n != msgLen { t.Fatalf("written only %v, err %v", n, err) } - recvBuf := make([]byte, testDataLen) + recvBuf := make([]byte, msgLen) _, err = io.ReadFull(conn, recvBuf) if err != nil { t.Fatalf("failed to read back: %v", err) @@ -96,7 +94,7 @@ func runEchoTest(t *testing.T, conns []net.Conn, maxMsgLen int) { func TestMultiplex(t *testing.T) { const numStreams = 2000 // -race option limits the number of goroutines to 8192 const numConns = 4 - const maxMsgLen = 16384 + const msgLen = 16384 clientSession, serverSession, _ := makeSessionPair(numConns) go serveEcho(serverSession) @@ -111,15 +109,10 @@ func TestMultiplex(t *testing.T) { } //test echo - runEchoTest(t, streams, maxMsgLen) + runEchoTest(t, streams, msgLen) - assert.Eventuallyf(t, func() bool { - return clientSession.streamCount() == numStreams - }, time.Second, 10*time.Millisecond, "client stream count is wrong: %v", clientSession.streamCount()) - - assert.Eventuallyf(t, func() bool { - return serverSession.streamCount() == numStreams - }, time.Second, 10*time.Millisecond, "server stream count is wrong: %v", serverSession.streamCount()) + assert.EqualValues(t, numStreams, clientSession.streamCount(), "client stream count is wrong") + assert.EqualValues(t, numStreams, serverSession.streamCount(), "server stream count is wrong") // close one stream closing, streams := streams[0], streams[1:] diff --git a/internal/multiplex/obfs.go b/internal/multiplex/obfs.go index 0c1f8c6..1379072 100644 --- a/internal/multiplex/obfs.go +++ b/internal/multiplex/obfs.go @@ -12,7 +12,7 @@ import ( ) type Obfser func(*Frame, []byte, int) (int, error) -type Deobfser func([]byte) (*Frame, error) +type Deobfser func(*Frame, []byte) error var u32 = binary.BigEndian.Uint32 var u64 = binary.BigEndian.Uint64 @@ -135,9 +135,9 @@ func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Obfser { func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser { // frame header length + minimum data size (i.e. nonce size of salsa20) const minInputLen = frameHeaderLength + salsa20NonceSize - deobfs := func(in []byte) (*Frame, error) { + deobfs := func(f *Frame, in []byte) error { if len(in) < minInputLen { - return nil, fmt.Errorf("input size %v, but it cannot be shorter than %v bytes", len(in), minInputLen) + return fmt.Errorf("input size %v, but it cannot be shorter than %v bytes", len(in), minInputLen) } header := in[:frameHeaderLength] @@ -153,7 +153,7 @@ func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser { usefulPayloadLen := len(pldWithOverHead) - int(extraLen) if usefulPayloadLen < 0 || usefulPayloadLen > len(pldWithOverHead) { - return nil, errors.New("extra length is negative or extra length is greater than total pldWithOverHead length") + return errors.New("extra length is negative or extra length is greater than total pldWithOverHead length") } var outputPayload []byte @@ -167,18 +167,16 @@ func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser { } else { _, err := payloadCipher.Open(pldWithOverHead[:0], header[:payloadCipher.NonceSize()], pldWithOverHead, nil) if err != nil { - return nil, err + return err } outputPayload = pldWithOverHead[:usefulPayloadLen] } - ret := &Frame{ - StreamID: streamID, - Seq: seq, - Closing: closing, - Payload: outputPayload, - } - return ret, nil + f.StreamID = streamID + f.Seq = seq + f.Closing = closing + f.Payload = outputPayload + return nil } return deobfs } diff --git a/internal/multiplex/obfs_test.go b/internal/multiplex/obfs_test.go index 6cbbb5b..99f4f5f 100644 --- a/internal/multiplex/obfs_test.go +++ b/internal/multiplex/obfs_test.go @@ -17,8 +17,7 @@ func TestGenerateObfs(t *testing.T) { run := func(obfuscator Obfuscator, ct *testing.T) { obfsBuf := make([]byte, 512) - f := &Frame{} - _testFrame, _ := quick.Value(reflect.TypeOf(f), rand.New(rand.NewSource(42))) + _testFrame, _ := quick.Value(reflect.TypeOf(&Frame{}), rand.New(rand.NewSource(42))) testFrame := _testFrame.Interface().(*Frame) i, err := obfuscator.Obfs(testFrame, obfsBuf, 0) if err != nil { @@ -26,7 +25,8 @@ func TestGenerateObfs(t *testing.T) { return } - resultFrame, err := obfuscator.Deobfs(obfsBuf[:i]) + var resultFrame Frame + err = obfuscator.Deobfs(&resultFrame, obfsBuf[:i]) if err != nil { ct.Error("failed to deobfs ", err) return @@ -148,10 +148,11 @@ func BenchmarkDeobfs(b *testing.B) { n, _ := obfs(testFrame, obfsBuf, 0) deobfs := MakeDeobfs(key, payloadCipher) + frame := new(Frame) b.SetBytes(int64(n)) b.ResetTimer() for i := 0; i < b.N; i++ { - deobfs(obfsBuf[:n]) + deobfs(frame, obfsBuf[:n]) } }) b.Run("AES128GCM", func(b *testing.B) { @@ -162,10 +163,11 @@ func BenchmarkDeobfs(b *testing.B) { n, _ := obfs(testFrame, obfsBuf, 0) deobfs := MakeDeobfs(key, payloadCipher) + frame := new(Frame) b.ResetTimer() b.SetBytes(int64(n)) for i := 0; i < b.N; i++ { - deobfs(obfsBuf[:n]) + deobfs(frame, obfsBuf[:n]) } }) b.Run("plain", func(b *testing.B) { @@ -173,10 +175,11 @@ func BenchmarkDeobfs(b *testing.B) { n, _ := obfs(testFrame, obfsBuf, 0) deobfs := MakeDeobfs(key, nil) + frame := new(Frame) b.ResetTimer() b.SetBytes(int64(n)) for i := 0; i < b.N; i++ { - deobfs(obfsBuf[:n]) + deobfs(frame, obfsBuf[:n]) } }) b.Run("chacha20Poly1305", func(b *testing.B) { @@ -186,10 +189,11 @@ func BenchmarkDeobfs(b *testing.B) { n, _ := obfs(testFrame, obfsBuf, 0) deobfs := MakeDeobfs(key, payloadCipher) + frame := new(Frame) b.ResetTimer() b.SetBytes(int64(n)) for i := 0; i < b.N; i++ { - deobfs(obfsBuf[:n]) + deobfs(frame, obfsBuf[:n]) } }) } diff --git a/internal/multiplex/recvBuffer.go b/internal/multiplex/recvBuffer.go index 0797daf..63f1f6f 100644 --- a/internal/multiplex/recvBuffer.go +++ b/internal/multiplex/recvBuffer.go @@ -15,7 +15,7 @@ type recvBuffer interface { // when the buffer is empty. io.ReadCloser io.WriterTo - Write(Frame) (toBeClosed bool, err error) + Write(*Frame) (toBeClosed bool, err error) SetReadDeadline(time time.Time) // SetWriteToTimeout sets the duration a recvBuffer waits in a WriteTo call when nothing // has been written for a while. After that duration it should return ErrTimeout diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index c808e2b..6f165b0 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -63,7 +63,12 @@ type Session struct { // atomic activeStreamCount uint32 - streams sync.Map + + streamsM sync.Mutex + streams map[uint32]*Stream + + // a pool of heap allocated frame objects so we don't have to allocate a new one each time we receive a frame + recvFramePool sync.Pool // Switchboard manages all connections to remote sb *switchboard @@ -89,6 +94,8 @@ func MakeSession(id uint32, config SessionConfig) *Session { SessionConfig: config, nextStreamID: 1, acceptCh: make(chan *Stream, acceptBacklog), + recvFramePool: sync.Pool{New: func() interface{} { return &Frame{} }}, + streams: map[uint32]*Stream{}, } sesh.addrs.Store([]net.Addr{nil, nil}) @@ -145,7 +152,9 @@ func (sesh *Session) OpenStream() (*Stream, error) { return nil, errNoMultiplex } stream := makeStream(sesh, id) - sesh.streams.Store(id, stream) + sesh.streamsM.Lock() + sesh.streams[id] = stream + sesh.streamsM.Unlock() sesh.streamCountIncr() log.Tracef("stream %v of session %v opened", id, sesh.id) return stream, nil @@ -165,24 +174,22 @@ func (sesh *Session) Accept() (net.Conn, error) { } func (sesh *Session) closeStream(s *Stream, active bool) error { + // must be holding s.wirtingM on entry if atomic.SwapUint32(&s.closed, 1) == 1 { return fmt.Errorf("closing stream %v: %w", s.id, errRepeatStreamClosing) } - _ = s.recvBuf.Close() // recvBuf.Close should not return error + _ = s.getRecvBuf().Close() // recvBuf.Close should not return error if active { // Notify remote that this stream is closed padding := genRandomPadding() - f := &Frame{ - StreamID: s.id, - Seq: s.nextSendSeq, - Closing: closingStream, - Payload: padding, - } - s.nextSendSeq++ + s.writingFrame.Closing = closingStream + s.writingFrame.Payload = padding obfsBuf := make([]byte, len(padding)+frameHeaderLength+sesh.Obfuscator.maxOverhead) - i, err := sesh.Obfs(f, obfsBuf, 0) + + i, err := sesh.Obfs(&s.writingFrame, obfsBuf, 0) + s.writingFrame.Seq++ if err != nil { return err } @@ -190,7 +197,7 @@ func (sesh *Session) closeStream(s *Stream, active bool) error { if err != nil { return err } - log.Tracef("stream %v actively closed. seq %v", s.id, f.Seq) + log.Tracef("stream %v actively closed.", s.id) } else { log.Tracef("stream %v passively closed", s.id) } @@ -198,7 +205,9 @@ func (sesh *Session) closeStream(s *Stream, active bool) error { // We set it as nil to signify that the stream id had existed before. // If we Delete(s.id) straight away, later on in recvDataFromRemote, it will not be able to tell // if the frame it received was from a new stream or a dying stream whose frame arrived late - sesh.streams.Store(s.id, nil) + sesh.streamsM.Lock() + sesh.streams[s.id] = nil + sesh.streamsM.Unlock() if sesh.streamCountDecr() == 0 { if sesh.Singleplex { return sesh.Close() @@ -214,7 +223,10 @@ func (sesh *Session) closeStream(s *Stream, active bool) error { // to the stream buffer, otherwise it fetches the desired stream instance, or creates and stores one if it's a new // stream and then writes to the stream buffer func (sesh *Session) recvDataFromRemote(data []byte) error { - frame, err := sesh.Deobfs(data) + frame := sesh.recvFramePool.Get().(*Frame) + defer sesh.recvFramePool.Put(frame) + + err := sesh.Deobfs(frame, data) if err != nil { return fmt.Errorf("Failed to decrypt a frame for session %v: %v", sesh.id, err) } @@ -224,19 +236,23 @@ func (sesh *Session) recvDataFromRemote(data []byte) error { return sesh.passiveClose() } - newStream := makeStream(sesh, frame.StreamID) - existingStreamI, existing := sesh.streams.LoadOrStore(frame.StreamID, newStream) + sesh.streamsM.Lock() + existingStream, existing := sesh.streams[frame.StreamID] if existing { - if existingStreamI == nil { + sesh.streamsM.Unlock() + if existingStream == nil { // this is when the stream existed before but has since been closed. We do nothing return nil } - return existingStreamI.(*Stream).recvFrame(*frame) + return existingStream.recvFrame(frame) } else { + newStream := makeStream(sesh, frame.StreamID) + sesh.streams[frame.StreamID] = newStream + sesh.streamsM.Unlock() // new stream sesh.streamCountIncr() sesh.acceptCh <- newStream - return newStream.recvFrame(*frame) + return newStream.recvFrame(frame) } } @@ -260,17 +276,17 @@ func (sesh *Session) closeSession(closeSwitchboard bool) error { } sesh.acceptCh <- nil - sesh.streams.Range(func(key, streamI interface{}) bool { - if streamI == nil { - return true + sesh.streamsM.Lock() + for id, stream := range sesh.streams { + if stream == nil { + continue } - stream := streamI.(*Stream) atomic.StoreUint32(&stream.closed, 1) - _ = stream.recvBuf.Close() // will not block - sesh.streams.Delete(key) + _ = stream.getRecvBuf().Close() // will not block + delete(sesh.streams, id) sesh.streamCountDecr() - return true - }) + } + sesh.streamsM.Unlock() if closeSwitchboard { sesh.sb.closeAll() diff --git a/internal/multiplex/session_test.go b/internal/multiplex/session_test.go index b280895..f4b32bb 100644 --- a/internal/multiplex/session_test.go +++ b/internal/multiplex/session_test.go @@ -12,10 +12,9 @@ import ( "time" ) -var seshConfigOrdered = SessionConfig{} - -var seshConfigUnordered = SessionConfig{ - Unordered: true, +var seshConfigs = map[string]SessionConfig{ + "ordered": {}, + "unordered": {Unordered: true}, } const testPayloadLen = 1024 @@ -43,40 +42,20 @@ func TestRecvDataFromRemote(t *testing.T) { return ret } - sessionTypes := []struct { - name string - config SessionConfig - }{ - {"ordered", - SessionConfig{}}, - {"unordered", - SessionConfig{Unordered: true}}, + encryptionMethods := map[string]Obfuscator{ + "plain": MakeObfuscatorUnwrap(EncryptionMethodPlain, sessionKey), + "aes-gcm": MakeObfuscatorUnwrap(EncryptionMethodAESGCM, sessionKey), + "chacha20-poly1305": MakeObfuscatorUnwrap(EncryptionMethodChaha20Poly1305, sessionKey), } - encryptionMethods := []struct { - name string - obfuscator Obfuscator - }{ - { - "plain", - MakeObfuscatorUnwrap(EncryptionMethodPlain, sessionKey), - }, - { - "aes-gcm", - MakeObfuscatorUnwrap(EncryptionMethodAESGCM, sessionKey), - }, - { - "chacha20-poly1305", - MakeObfuscatorUnwrap(EncryptionMethodChaha20Poly1305, sessionKey), - }, - } - - for _, st := range sessionTypes { - t.Run(st.name, func(t *testing.T) { - for _, em := range encryptionMethods { - t.Run(em.name, func(t *testing.T) { - st.config.Obfuscator = em.obfuscator - sesh := MakeSession(0, st.config) + for seshType, seshConfig := range seshConfigs { + seshConfig := seshConfig + t.Run(seshType, func(t *testing.T) { + for method, obfuscator := range encryptionMethods { + obfuscator := obfuscator + t.Run(method, func(t *testing.T) { + seshConfig.Obfuscator = obfuscator + sesh := MakeSession(0, seshConfig) n, err := sesh.Obfs(f, obfsBuf, 0) if err != nil { t.Error(err) @@ -116,8 +95,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { var sessionKey [32]byte rand.Read(sessionKey[:]) obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) - seshConfigOrdered.Obfuscator = obfuscator - sesh := MakeSession(0, seshConfigOrdered) + + seshConfig := seshConfigs["ordered"] + seshConfig.Obfuscator = obfuscator + sesh := MakeSession(0, seshConfig) f1 := &Frame{ 1, @@ -131,7 +112,9 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { if err != nil { t.Fatalf("receiving normal frame for stream 1: %v", err) } - _, ok := sesh.streams.Load(f1.StreamID) + sesh.streamsM.Lock() + _, ok := sesh.streams[f1.StreamID] + sesh.streamsM.Unlock() if !ok { t.Fatal("failed to fetch stream 1 after receiving it") } @@ -151,8 +134,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { if err != nil { t.Fatalf("receiving normal frame for stream 2: %v", err) } - s2I, ok := sesh.streams.Load(f2.StreamID) - if s2I == nil || !ok { + sesh.streamsM.Lock() + s2M, ok := sesh.streams[f2.StreamID] + sesh.streamsM.Unlock() + if s2M == nil || !ok { t.Fatal("failed to fetch stream 2 after receiving it") } if sesh.streamCount() != 2 { @@ -171,8 +156,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { if err != nil { t.Fatalf("receiving stream closing frame for stream 1: %v", err) } - s1I, _ := sesh.streams.Load(f1.StreamID) - if s1I != nil { + sesh.streamsM.Lock() + s1M, _ := sesh.streams[f1.StreamID] + sesh.streamsM.Unlock() + if s1M != nil { t.Fatal("stream 1 still exist after receiving stream close") } s1, _ := sesh.Accept() @@ -198,8 +185,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) { if err != nil { t.Fatalf("receiving stream closing frame for stream 1 %v", err) } - s1I, _ = sesh.streams.Load(f1.StreamID) - if s1I != nil { + sesh.streamsM.Lock() + s1M, _ = sesh.streams[f1.StreamID] + sesh.streamsM.Unlock() + if s1M != nil { t.Error("stream 1 exists after receiving stream close for the second time") } streamCount := sesh.streamCount() @@ -245,8 +234,10 @@ func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) { var sessionKey [32]byte rand.Read(sessionKey[:]) obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) - seshConfigOrdered.Obfuscator = obfuscator - sesh := MakeSession(0, seshConfigOrdered) + + seshConfig := seshConfigs["ordered"] + seshConfig.Obfuscator = obfuscator + sesh := MakeSession(0, seshConfig) // receive stream 1 closing first f1CloseStream := &Frame{ @@ -260,7 +251,9 @@ func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) { if err != nil { t.Fatalf("receiving out of order stream closing frame for stream 1: %v", err) } - _, ok := sesh.streams.Load(f1CloseStream.StreamID) + sesh.streamsM.Lock() + _, ok := sesh.streams[f1CloseStream.StreamID] + sesh.streamsM.Unlock() if !ok { t.Fatal("stream 1 doesn't exist") } @@ -300,119 +293,126 @@ func TestParallelStreams(t *testing.T) { var sessionKey [32]byte rand.Read(sessionKey[:]) obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) - seshConfigOrdered.Obfuscator = obfuscator - sesh := MakeSession(0, seshConfigOrdered) - numStreams := acceptBacklog - seqs := make([]*uint64, numStreams) - for i := range seqs { - seqs[i] = new(uint64) - } - randFrame := func() *Frame { - id := rand.Intn(numStreams) - return &Frame{ - uint32(id), - atomic.AddUint64(seqs[id], 1) - 1, - uint8(rand.Intn(2)), - []byte{1, 2, 3, 4}, - } - } + for seshType, seshConfig := range seshConfigs { + seshConfig := seshConfig + t.Run(seshType, func(t *testing.T) { + seshConfig.Obfuscator = obfuscator + sesh := MakeSession(0, seshConfig) - const numOfTests = 5000 - tests := make([]struct { - name string - frame *Frame - }, numOfTests) - for i := range tests { - tests[i].name = strconv.Itoa(i) - tests[i].frame = randFrame() - } - - var wg sync.WaitGroup - for _, tc := range tests { - wg.Add(1) - go func(frame *Frame) { - obfsBuf := make([]byte, obfsBufLen) - n, _ := sesh.Obfs(frame, obfsBuf, 0) - obfsBuf = obfsBuf[0:n] - - err := sesh.recvDataFromRemote(obfsBuf) - if err != nil { - t.Error(err) + numStreams := acceptBacklog + seqs := make([]*uint64, numStreams) + for i := range seqs { + seqs[i] = new(uint64) + } + randFrame := func() *Frame { + id := rand.Intn(numStreams) + return &Frame{ + uint32(id), + atomic.AddUint64(seqs[id], 1) - 1, + uint8(rand.Intn(2)), + []byte{1, 2, 3, 4}, + } } - wg.Done() - }(tc.frame) - } - wg.Wait() - sc := int(sesh.streamCount()) - var count int - sesh.streams.Range(func(_, s interface{}) bool { - if s != nil { - count++ - } - return true - }) - if sc != count { - t.Errorf("broken referential integrety: actual %v, reference count: %v", count, sc) + const numOfTests = 5000 + tests := make([]struct { + name string + frame *Frame + }, numOfTests) + for i := range tests { + tests[i].name = strconv.Itoa(i) + tests[i].frame = randFrame() + } + + var wg sync.WaitGroup + for _, tc := range tests { + wg.Add(1) + go func(frame *Frame) { + obfsBuf := make([]byte, obfsBufLen) + n, _ := sesh.Obfs(frame, obfsBuf, 0) + obfsBuf = obfsBuf[0:n] + + err := sesh.recvDataFromRemote(obfsBuf) + if err != nil { + t.Error(err) + } + wg.Done() + }(tc.frame) + } + + wg.Wait() + sc := int(sesh.streamCount()) + var count int + sesh.streamsM.Lock() + for _, s := range sesh.streams { + if s != nil { + count++ + } + } + sesh.streamsM.Unlock() + if sc != count { + t.Errorf("broken referential integrety: actual %v, reference count: %v", count, sc) + } + }) } } func TestStream_SetReadDeadline(t *testing.T) { - var sessionKey [32]byte - rand.Read(sessionKey[:]) - obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) - seshConfigOrdered.Obfuscator = obfuscator + for seshType, seshConfig := range seshConfigs { + seshConfig := seshConfig + t.Run(seshType, func(t *testing.T) { + sesh := MakeSession(0, seshConfig) + sesh.AddConnection(connutil.Discard()) - testReadDeadline := func(sesh *Session) { - t.Run("read after deadline set", func(t *testing.T) { - stream, _ := sesh.OpenStream() - _ = stream.SetReadDeadline(time.Now().Add(-1 * time.Second)) - _, err := stream.Read(make([]byte, 1)) - if err != ErrTimeout { - t.Errorf("expecting error %v, got %v", ErrTimeout, err) - } - }) + t.Run("read after deadline set", func(t *testing.T) { + stream, _ := sesh.OpenStream() + _ = stream.SetReadDeadline(time.Now().Add(-1 * time.Second)) + _, err := stream.Read(make([]byte, 1)) + if err != ErrTimeout { + t.Errorf("expecting error %v, got %v", ErrTimeout, err) + } + }) - t.Run("unblock when deadline passed", func(t *testing.T) { - stream, _ := sesh.OpenStream() + t.Run("unblock when deadline passed", func(t *testing.T) { + stream, _ := sesh.OpenStream() - done := make(chan struct{}) - go func() { - _, _ = stream.Read(make([]byte, 1)) - done <- struct{}{} - }() + done := make(chan struct{}) + go func() { + _, _ = stream.Read(make([]byte, 1)) + done <- struct{}{} + }() - _ = stream.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + _ = stream.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) - select { - case <-done: - return - case <-time.After(500 * time.Millisecond): - t.Error("Read did not unblock after deadline has passed") - } + select { + case <-done: + return + case <-time.After(500 * time.Millisecond): + t.Error("Read did not unblock after deadline has passed") + } + }) }) } - - sesh := MakeSession(0, seshConfigOrdered) - sesh.AddConnection(connutil.Discard()) - testReadDeadline(sesh) - sesh = MakeSession(0, seshConfigUnordered) - sesh.AddConnection(connutil.Discard()) - testReadDeadline(sesh) } func TestSession_timeoutAfter(t *testing.T) { var sessionKey [32]byte rand.Read(sessionKey[:]) obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) - seshConfigOrdered.Obfuscator = obfuscator - seshConfigOrdered.InactivityTimeout = 100 * time.Millisecond - sesh := MakeSession(0, seshConfigOrdered) - assert.Eventually(t, func() bool { - return sesh.IsClosed() - }, 5*seshConfigOrdered.InactivityTimeout, seshConfigOrdered.InactivityTimeout, "session should have timed out") + for seshType, seshConfig := range seshConfigs { + seshConfig := seshConfig + t.Run(seshType, func(t *testing.T) { + seshConfig.Obfuscator = obfuscator + seshConfig.InactivityTimeout = 100 * time.Millisecond + sesh := MakeSession(0, seshConfig) + + assert.Eventually(t, func() bool { + return sesh.IsClosed() + }, 5*seshConfig.InactivityTimeout, seshConfig.InactivityTimeout, "session should have timed out") + }) + } } func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) { @@ -424,47 +424,73 @@ func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) { 0, testPayload, } - obfsBuf := make([]byte, obfsBufLen) var sessionKey [32]byte rand.Read(sessionKey[:]) - b.Run("plain", func(b *testing.B) { - obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) - seshConfigOrdered.Obfuscator = obfuscator - sesh := MakeSession(0, seshConfigOrdered) - n, _ := sesh.Obfs(f, obfsBuf, 0) + table := map[string]byte{ + "plain": EncryptionMethodPlain, + "aes-gcm": EncryptionMethodAESGCM, + "chacha20poly1305": EncryptionMethodChaha20Poly1305, + } - b.SetBytes(int64(len(f.Payload))) - b.ResetTimer() - for i := 0; i < b.N; i++ { - sesh.recvDataFromRemote(obfsBuf[:n]) - } - }) + const maxIter = 100_000 // run with -benchtime 100000x to avoid index out of bounds panic + for name, ep := range table { + ep := ep + b.Run(name, func(b *testing.B) { + seshConfig := seshConfigs["ordered"] + obfuscator, _ := MakeObfuscator(ep, sessionKey) + seshConfig.Obfuscator = obfuscator + sesh := MakeSession(0, seshConfig) - b.Run("aes-gcm", func(b *testing.B) { - obfuscator, _ := MakeObfuscator(EncryptionMethodAESGCM, sessionKey) - seshConfigOrdered.Obfuscator = obfuscator - sesh := MakeSession(0, seshConfigOrdered) - n, _ := sesh.Obfs(f, obfsBuf, 0) + binaryFrames := [maxIter][]byte{} + for i := 0; i < maxIter; i++ { + obfsBuf := make([]byte, obfsBufLen) + n, _ := sesh.Obfs(f, obfsBuf, 0) + binaryFrames[i] = obfsBuf[:n] + f.Seq++ + } - b.SetBytes(int64(len(f.Payload))) - b.ResetTimer() - for i := 0; i < b.N; i++ { - sesh.recvDataFromRemote(obfsBuf[:n]) - } - }) - - b.Run("chacha20-poly1305", func(b *testing.B) { - obfuscator, _ := MakeObfuscator(EncryptionMethodChaha20Poly1305, sessionKey) - seshConfigOrdered.Obfuscator = obfuscator - sesh := MakeSession(0, seshConfigOrdered) - n, _ := sesh.Obfs(f, obfsBuf, 0) - - b.SetBytes(int64(len(f.Payload))) - b.ResetTimer() - for i := 0; i < b.N; i++ { - sesh.recvDataFromRemote(obfsBuf[:n]) - } - }) + b.SetBytes(int64(len(f.Payload))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + sesh.recvDataFromRemote(binaryFrames[i]) + } + }) + } +} + +func BenchmarkMultiStreamWrite(b *testing.B) { + var sessionKey [32]byte + rand.Read(sessionKey[:]) + + table := map[string]byte{ + "plain": EncryptionMethodPlain, + "aes-gcm": EncryptionMethodAESGCM, + "chacha20poly1305": EncryptionMethodChaha20Poly1305, + } + + testPayload := make([]byte, testPayloadLen) + + for name, ep := range table { + b.Run(name, func(b *testing.B) { + for seshType, seshConfig := range seshConfigs { + seshConfig := seshConfig + b.Run(seshType, func(b *testing.B) { + obfuscator, _ := MakeObfuscator(ep, sessionKey) + seshConfig.Obfuscator = obfuscator + sesh := MakeSession(0, seshConfig) + sesh.AddConnection(connutil.Discard()) + b.ResetTimer() + b.SetBytes(testPayloadLen) + b.RunParallel(func(pb *testing.PB) { + stream, _ := sesh.OpenStream() + for pb.Next() { + stream.Write(testPayload) + } + }) + }) + } + }) + } } diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index beee2b8..d827117 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -23,21 +23,20 @@ type Stream struct { session *Session + allocIdempot sync.Once // a buffer (implemented as an asynchronous buffered pipe) to put data we've received from recvFrame but hasn't - // been read by the consumer through Read or WriteTo + // been read by the consumer through Read or WriteTo. Lazily allocated recvBuf recvBuffer - writingM sync.Mutex - nextSendSeq uint64 + writingM sync.Mutex + writingFrame Frame // we do the allocation here to save repeated allocations in Write and ReadFrom // atomic closed uint32 - // lazy allocation for obfsBuf. This is desirable because obfsBuf is only used when data is sent from + // obfuscation happens in obfsBuf. This buffer is lazily allocated as obfsBuf is only used when data is sent from // the stream (through Write or ReadFrom). Some streams never send data so eager allocation will waste // memory - allocIdempot sync.Once - // obfuscation happens in this buffer obfsBuf []byte // When we want order guarantee (i.e. session.Unordered is false), @@ -52,17 +51,14 @@ type Stream struct { } func makeStream(sesh *Session, id uint32) *Stream { - var recvBuf recvBuffer - if sesh.Unordered { - recvBuf = NewDatagramBufferedPipe() - } else { - recvBuf = NewStreamBuffer() - } - stream := &Stream{ id: id, session: sesh, - recvBuf: recvBuf, + writingFrame: Frame{ + StreamID: id, + Seq: 0, + Closing: closingNothing, + }, } return stream @@ -70,9 +66,20 @@ func makeStream(sesh *Session, id uint32) *Stream { func (s *Stream) isClosed() bool { return atomic.LoadUint32(&s.closed) == 1 } +func (s *Stream) getRecvBuf() recvBuffer { + s.allocIdempot.Do(func() { + if s.session.Unordered { + s.recvBuf = NewDatagramBufferedPipe() + } else { + s.recvBuf = NewStreamBuffer() + } + }) + return s.recvBuf +} + // receive a readily deobfuscated Frame so its payload can later be Read -func (s *Stream) recvFrame(frame Frame) error { - toBeClosed, err := s.recvBuf.Write(frame) +func (s *Stream) recvFrame(frame *Frame) error { + toBeClosed, err := s.getRecvBuf().Write(frame) if toBeClosed { err = s.passiveClose() if errors.Is(err, errRepeatStreamClosing) { @@ -91,7 +98,7 @@ func (s *Stream) Read(buf []byte) (n int, err error) { return 0, nil } - n, err = s.recvBuf.Read(buf) + n, err = s.getRecvBuf().Read(buf) log.Tracef("%v read from stream %v with err %v", n, s.id, err) if err == io.EOF { return n, ErrBrokenStream @@ -102,7 +109,7 @@ func (s *Stream) Read(buf []byte) (n int, err error) { // WriteTo continuously write data Stream has received into the writer w. func (s *Stream) WriteTo(w io.Writer) (int64, error) { // will keep writing until the underlying buffer is closed - n, err := s.recvBuf.WriteTo(w) + n, err := s.getRecvBuf().WriteTo(w) log.Tracef("%v read from stream %v with err %v", n, s.id, err) if err == io.EOF { return n, ErrBrokenStream @@ -110,15 +117,14 @@ func (s *Stream) WriteTo(w io.Writer) (int64, error) { return n, nil } -func (s *Stream) obfuscateAndSend(f *Frame, payloadOffsetInObfsBuf int) error { +func (s *Stream) obfuscateAndSend(payloadOffsetInObfsBuf int) error { var cipherTextLen int - cipherTextLen, err := s.session.Obfs(f, s.obfsBuf, payloadOffsetInObfsBuf) + cipherTextLen, err := s.session.Obfs(&s.writingFrame, s.obfsBuf, payloadOffsetInObfsBuf) if err != nil { return err } _, err = s.session.sb.send(s.obfsBuf[:cipherTextLen], &s.assignedConnId) - log.Tracef("%v sent to remote through stream %v with err %v. seq: %v", len(f.Payload), s.id, err, f.Seq) if err != nil { if err == errBrokenSwitchboard { s.session.SetTerminalMsg(err.Error()) @@ -154,14 +160,9 @@ func (s *Stream) Write(in []byte) (n int, err error) { } framePayload = in[n : s.session.maxStreamUnitWrite+n] } - f := &Frame{ - StreamID: s.id, - Seq: s.nextSendSeq, - Closing: closingNothing, - Payload: framePayload, - } - s.nextSendSeq++ - err = s.obfuscateAndSend(f, 0) + s.writingFrame.Payload = framePayload + err = s.obfuscateAndSend(0) + s.writingFrame.Seq++ if err != nil { return } @@ -193,14 +194,9 @@ func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) { } s.writingM.Lock() - f := &Frame{ - StreamID: s.id, - Seq: s.nextSendSeq, - Closing: closingNothing, - Payload: s.obfsBuf[frameHeaderLength : frameHeaderLength+read], - } - s.nextSendSeq++ - err = s.obfuscateAndSend(f, frameHeaderLength) + s.writingFrame.Payload = s.obfsBuf[frameHeaderLength : frameHeaderLength+read] + err = s.obfuscateAndSend(frameHeaderLength) + s.writingFrame.Seq++ s.writingM.Unlock() if err != nil { @@ -225,8 +221,8 @@ func (s *Stream) Close() error { func (s *Stream) LocalAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[0] } func (s *Stream) RemoteAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[1] } -func (s *Stream) SetWriteToTimeout(d time.Duration) { s.recvBuf.SetWriteToTimeout(d) } -func (s *Stream) SetReadDeadline(t time.Time) error { s.recvBuf.SetReadDeadline(t); return nil } +func (s *Stream) SetWriteToTimeout(d time.Duration) { s.getRecvBuf().SetWriteToTimeout(d) } +func (s *Stream) SetReadDeadline(t time.Time) error { s.getRecvBuf().SetReadDeadline(t); return nil } func (s *Stream) SetReadFromTimeout(d time.Duration) { s.readFromTimeout = d } var errNotImplemented = errors.New("Not implemented") diff --git a/internal/multiplex/streamBuffer.go b/internal/multiplex/streamBuffer.go index 4adfae2..13cc523 100644 --- a/internal/multiplex/streamBuffer.go +++ b/internal/multiplex/streamBuffer.go @@ -63,7 +63,7 @@ func NewStreamBuffer() *streamBuffer { return sb } -func (sb *streamBuffer) Write(f Frame) (toBeClosed bool, err error) { +func (sb *streamBuffer) Write(f *Frame) (toBeClosed bool, err error) { sb.recvM.Lock() defer sb.recvM.Unlock() // when there'fs no ooo packages in heap and we receive the next package in order @@ -81,10 +81,11 @@ func (sb *streamBuffer) Write(f Frame) (toBeClosed bool, err error) { return false, fmt.Errorf("seq %v is smaller than nextRecvSeq %v", f.Seq, sb.nextRecvSeq) } - heap.Push(&sb.sh, &f) + saved := *f + heap.Push(&sb.sh, &saved) // Keep popping from the heap until empty or to the point that the wanted seq was not received for len(sb.sh) > 0 && sb.sh[0].Seq == sb.nextRecvSeq { - f = *heap.Pop(&sb.sh).(*Frame) + f = heap.Pop(&sb.sh).(*Frame) if f.Closing != closingNothing { return true, nil } else { diff --git a/internal/multiplex/streamBuffer_test.go b/internal/multiplex/streamBuffer_test.go index 67fb3a5..b36bb6a 100644 --- a/internal/multiplex/streamBuffer_test.go +++ b/internal/multiplex/streamBuffer_test.go @@ -20,11 +20,10 @@ func TestRecvNewFrame(t *testing.T) { for _, n := range set { bu64 := make([]byte, 8) binary.BigEndian.PutUint64(bu64, n) - frame := Frame{ + sb.Write(&Frame{ Seq: n, Payload: bu64, - } - sb.Write(frame) + }) } var sortedResult []uint64 @@ -80,7 +79,7 @@ func TestStreamBuffer_RecvThenClose(t *testing.T) { Closing: 0, Payload: testData, } - sb.Write(testFrame) + sb.Write(&testFrame) sb.Close() readBuf := make([]byte, testDataLen) diff --git a/internal/multiplex/stream_test.go b/internal/multiplex/stream_test.go index 6ce16d9..c0b86fb 100644 --- a/internal/multiplex/stream_test.go +++ b/internal/multiplex/stream_test.go @@ -151,19 +151,31 @@ func TestStream_Close(t *testing.T) { t.Error("failed to accept stream", err) return } + + // we read something to wait for the test frame to reach our recvBuffer. + // if it's empty by the point we call stream.Close(), the incoming + // frame will be dropped + readBuf := make([]byte, len(testPayload)) + _, err = io.ReadFull(stream, readBuf[:1]) + if err != nil { + t.Errorf("can't read any data before active closing") + } + err = stream.Close() if err != nil { t.Error("failed to actively close stream", err) return } - if sI, _ := sesh.streams.Load(stream.(*Stream).id); sI != nil { + sesh.streamsM.Lock() + if s, _ := sesh.streams[stream.(*Stream).id]; s != nil { + sesh.streamsM.Unlock() t.Error("stream still exists") return } + sesh.streamsM.Unlock() - readBuf := make([]byte, len(testPayload)) - _, err = io.ReadFull(stream, readBuf) + _, err = io.ReadFull(stream, readBuf[1:]) if err != nil { t.Errorf("can't read residual data %v", err) } @@ -233,8 +245,10 @@ func TestStream_Close(t *testing.T) { } assert.Eventually(t, func() bool { - sI, _ := sesh.streams.Load(stream.(*Stream).id) - return sI == nil + sesh.streamsM.Lock() + s, _ := sesh.streams[stream.(*Stream).id] + sesh.streamsM.Unlock() + return s == nil }, time.Second, 10*time.Millisecond, "streams still exists") }) diff --git a/internal/server/TLS.go b/internal/server/TLS.go index 8a0ea6a..0e66387 100644 --- a/internal/server/TLS.go +++ b/internal/server/TLS.go @@ -79,7 +79,13 @@ func (TLS) unmarshalClientHello(ch *ClientHello, staticPv crypto.PrivateKey) (fr return } - copy(fragments.sharedSecret[:], ecdh.GenerateSharedSecret(staticPv, ephPub)) + var sharedSecret []byte + sharedSecret, err = ecdh.GenerateSharedSecret(staticPv, ephPub) + if err != nil { + return + } + + copy(fragments.sharedSecret[:], sharedSecret) var keyShare []byte keyShare, err = parseKeyShare(ch.extensions[[2]byte{0x00, 0x33}]) if err != nil { diff --git a/internal/server/state.go b/internal/server/state.go index 576e326..03d9298 100644 --- a/internal/server/state.go +++ b/internal/server/state.go @@ -143,9 +143,14 @@ func InitState(preParse RawConfig, worldState common.WorldState) (sta *State, er err = errors.New("command & control mode not implemented") return } else { - manager, err := usermanager.MakeLocalManager(preParse.DatabasePath, worldState) - if err != nil { - return sta, err + var manager usermanager.UserManager + if len(preParse.AdminUID) == 0 || preParse.DatabasePath == "" { + manager = &usermanager.Voidmanager{} + } else { + manager, err = usermanager.MakeLocalManager(preParse.DatabasePath, worldState) + if err != nil { + return sta, err + } } sta.Panel = MakeUserPanel(manager) } diff --git a/internal/server/usermanager/usermanager.go b/internal/server/usermanager/usermanager.go index bfd5cc4..7bf84d5 100644 --- a/internal/server/usermanager/usermanager.go +++ b/internal/server/usermanager/usermanager.go @@ -40,6 +40,7 @@ const ( var ErrUserNotFound = errors.New("UID does not correspond to a user") var ErrSessionsCapReached = errors.New("Sessions cap has reached") +var ErrMangerIsVoid = errors.New("cannot perform operation with user manager as database path is not specified") var ErrNoUpCredit = errors.New("No upload credit left") var ErrNoDownCredit = errors.New("No download credit left") diff --git a/internal/server/usermanager/voidmanager.go b/internal/server/usermanager/voidmanager.go new file mode 100644 index 0000000..a20ab3c --- /dev/null +++ b/internal/server/usermanager/voidmanager.go @@ -0,0 +1,31 @@ +package usermanager + +type Voidmanager struct{} + +func (v *Voidmanager) AuthenticateUser(bytes []byte) (int64, int64, error) { + return 0, 0, ErrMangerIsVoid +} + +func (v *Voidmanager) AuthoriseNewSession(bytes []byte, info AuthorisationInfo) error { + return ErrMangerIsVoid +} + +func (v *Voidmanager) UploadStatus(updates []StatusUpdate) ([]StatusResponse, error) { + return nil, ErrMangerIsVoid +} + +func (v *Voidmanager) ListAllUsers() ([]UserInfo, error) { + return nil, ErrMangerIsVoid +} + +func (v *Voidmanager) GetUserInfo(UID []byte) (UserInfo, error) { + return UserInfo{}, ErrMangerIsVoid +} + +func (v *Voidmanager) WriteUserInfo(info UserInfo) error { + return ErrMangerIsVoid +} + +func (v *Voidmanager) DeleteUser(UID []byte) error { + return ErrMangerIsVoid +} diff --git a/internal/server/usermanager/voidmanager_test.go b/internal/server/usermanager/voidmanager_test.go new file mode 100644 index 0000000..55ab2b4 --- /dev/null +++ b/internal/server/usermanager/voidmanager_test.go @@ -0,0 +1,43 @@ +package usermanager + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +var v = &Voidmanager{} + +func Test_Voidmanager_AuthenticateUser(t *testing.T) { + _, _, err := v.AuthenticateUser([]byte{}) + assert.Equal(t, ErrMangerIsVoid, err) +} + +func Test_Voidmanager_AuthoriseNewSession(t *testing.T) { + err := v.AuthoriseNewSession([]byte{}, AuthorisationInfo{}) + assert.Equal(t, ErrMangerIsVoid, err) +} + +func Test_Voidmanager_DeleteUser(t *testing.T) { + err := v.DeleteUser([]byte{}) + assert.Equal(t, ErrMangerIsVoid, err) +} + +func Test_Voidmanager_GetUserInfo(t *testing.T) { + _, err := v.GetUserInfo([]byte{}) + assert.Equal(t, ErrMangerIsVoid, err) +} + +func Test_Voidmanager_ListAllUsers(t *testing.T) { + _, err := v.ListAllUsers() + assert.Equal(t, ErrMangerIsVoid, err) +} + +func Test_Voidmanager_UploadStatus(t *testing.T) { + _, err := v.UploadStatus([]StatusUpdate{}) + assert.Equal(t, ErrMangerIsVoid, err) +} + +func Test_Voidmanager_WriteUserInfo(t *testing.T) { + err := v.WriteUserInfo(UserInfo{}) + assert.Equal(t, ErrMangerIsVoid, err) +} diff --git a/internal/server/userpanel.go b/internal/server/userpanel.go index 453ff39..953179e 100644 --- a/internal/server/userpanel.go +++ b/internal/server/userpanel.go @@ -185,6 +185,9 @@ func (panel *userPanel) commitUpdate() error { panel.usageUpdateQueue = make(map[[16]byte]*usagePair) panel.usageUpdateQueueM.Unlock() + if len(statuses) == 0 { + return nil + } responses, err := panel.Manager.UploadStatus(statuses) if err != nil { return err diff --git a/internal/server/websocket.go b/internal/server/websocket.go index 2b192b9..1c9e940 100644 --- a/internal/server/websocket.go +++ b/internal/server/websocket.go @@ -84,7 +84,13 @@ func (WebSocket) unmarshalHidden(hidden []byte, staticPv crypto.PrivateKey) (fra return } - copy(fragments.sharedSecret[:], ecdh.GenerateSharedSecret(staticPv, ephPub)) + var sharedSecret []byte + sharedSecret, err = ecdh.GenerateSharedSecret(staticPv, ephPub) + if err != nil { + return + } + + copy(fragments.sharedSecret[:], sharedSecret) if len(hidden[32:]) != 64 { err = fmt.Errorf("%v: %v", ErrCiphertextLength, len(hidden[32:])) diff --git a/internal/test/integration_test.go b/internal/test/integration_test.go index db58b39..5812ba2 100644 --- a/internal/test/integration_test.go +++ b/internal/test/integration_test.go @@ -12,10 +12,8 @@ import ( "github.com/cbeuw/connutil" "github.com/stretchr/testify/assert" "io" - "io/ioutil" "math/rand" "net" - "os" "sync" "testing" "time" @@ -24,8 +22,6 @@ import ( ) const numConns = 200 // -race option limits the number of goroutines to 8192 -const delayBeforeTestingConnClose = 500 * time.Millisecond -const connCloseRetries = 3 func serveTCPEcho(l net.Listener) { for { @@ -137,17 +133,15 @@ func generateClientConfigs(rawConfig client.RawConfig, state common.WorldState) return lcl, rmt, auth } -func basicServerState(ws common.WorldState, db *os.File) *server.State { +func basicServerState(ws common.WorldState) *server.State { var serverConfig = server.RawConfig{ - ProxyBook: map[string][]string{"shadowsocks": {"tcp", "fake.com:9999"}, "openvpn": {"udp", "fake.com:9999"}}, - BindAddr: []string{"fake.com:9999"}, - BypassUID: [][]byte{bypassUID[:]}, - RedirAddr: "fake.com:9999", - PrivateKey: privateKey, - AdminUID: nil, - DatabasePath: db.Name(), - KeepAlive: 15, - CncMode: false, + ProxyBook: map[string][]string{"shadowsocks": {"tcp", "fake.com:9999"}, "openvpn": {"udp", "fake.com:9999"}}, + BindAddr: []string{"fake.com:9999"}, + BypassUID: [][]byte{bypassUID[:]}, + RedirAddr: "fake.com:9999", + PrivateKey: privateKey, + KeepAlive: 15, + CncMode: false, } state, err := server.InitState(serverConfig, ws) if err != nil { @@ -258,13 +252,11 @@ func runEchoTest(t *testing.T, conns []net.Conn, maxMsgLen int) { } func TestUDP(t *testing.T) { - var tmpDB, _ = ioutil.TempFile("", "ck_user_info") - defer os.Remove(tmpDB.Name()) log.SetLevel(log.ErrorLevel) worldState := common.WorldOfTime(time.Unix(10, 0)) lcc, rcc, ai := generateClientConfigs(basicUDPConfig, worldState) - sta := basicServerState(worldState, tmpDB) + sta := basicServerState(worldState) proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta) if err != nil { @@ -319,9 +311,7 @@ func TestTCPSingleplex(t *testing.T) { log.SetLevel(log.ErrorLevel) worldState := common.WorldOfTime(time.Unix(10, 0)) lcc, rcc, ai := generateClientConfigs(singleplexTCPConfig, worldState) - var tmpDB, _ = ioutil.TempFile("", "ck_user_info") - defer os.Remove(tmpDB.Name()) - sta := basicServerState(worldState, tmpDB) + sta := basicServerState(worldState) proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta) if err != nil { t.Fatal(err) @@ -381,9 +371,7 @@ func TestTCPMultiplex(t *testing.T) { worldState := common.WorldOfTime(time.Unix(10, 0)) lcc, rcc, ai := generateClientConfigs(basicTCPConfig, worldState) - var tmpDB, _ = ioutil.TempFile("", "ck_user_info") - defer os.Remove(tmpDB.Name()) - sta := basicServerState(worldState, tmpDB) + sta := basicServerState(worldState) proxyToCkClientD, proxyFromCkServerL, netToCkServerD, redirFromCkServerL, err := establishSession(lcc, rcc, ai, sta) if err != nil { @@ -456,11 +444,8 @@ func TestClosingStreamsFromProxy(t *testing.T) { clientConfig := clientConfig clientConfigName := clientConfigName t.Run(clientConfigName, func(t *testing.T) { - var tmpDB, _ = ioutil.TempFile("", "ck_user_info") - defer os.Remove(tmpDB.Name()) - lcc, rcc, ai := generateClientConfigs(clientConfig, worldState) - sta := basicServerState(worldState, tmpDB) + sta := basicServerState(worldState) proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta) if err != nil { t.Fatal(err) @@ -519,12 +504,10 @@ func TestClosingStreamsFromProxy(t *testing.T) { } func BenchmarkThroughput(b *testing.B) { - var tmpDB, _ = ioutil.TempFile("", "ck_user_info") - defer os.Remove(tmpDB.Name()) log.SetLevel(log.ErrorLevel) worldState := common.WorldOfTime(time.Unix(10, 0)) lcc, rcc, ai := generateClientConfigs(basicTCPConfig, worldState) - sta := basicServerState(worldState, tmpDB) + sta := basicServerState(worldState) const bufSize = 16 * 1024 encryptionMethods := map[string]byte{ diff --git a/release.sh b/release.sh index 277f87a..bee82f8 100755 --- a/release.sh +++ b/release.sh @@ -1,3 +1,5 @@ +#!/usr/bin/env bash + go get github.com/mitchellh/gox mkdir -p release