diff --git a/README.md b/README.md index ebf70db..f380dac 100644 --- a/README.md +++ b/README.md @@ -135,6 +135,19 @@ encryption and authentication (via AEAD or similar techniques).** `ServerName` is the domain you want to make your ISP or firewall _think_ you are visiting. Ideally it should match `RedirAddr` in the server's configuration, a major site the censor allows, but it doesn't have to. +`AlternativeNames` is an array used alongside `ServerName` to shuffle between different ServerNames for every new +connection. **This may conflict with `CDN` Transport mode** if the CDN provider prohibits domain fronting and rejects +the alternative domains. + +Example: + +```json +{ + "ServerName": "bing.com", + "AlternativeNames": ["cloudflare.com", "github.com"] +} +``` + `CDNOriginHost` is the domain name of the _origin_ server (i.e. the server running Cloak) under `CDN` mode. This only has effect when `Transport` is set to `CDN`. If unset, it will default to the remote hostname supplied via the commandline argument (in standalone mode), or by Shadowsocks (in plugin mode). After a TLS session is established with diff --git a/cmd/ck-client/ck-client.go b/cmd/ck-client/ck-client.go index 2c19cd3..38bf95d 100644 --- a/cmd/ck-client/ck-client.go +++ b/cmd/ck-client/ck-client.go @@ -173,6 +173,11 @@ func main() { log.Infof("Listening on %v %v for %v client", network, localConfig.LocalAddr, authInfo.ProxyMethod) seshMaker = func() *mux.Session { authInfo := authInfo // copy the struct because we are overwriting SessionId + + randByte := make([]byte, 1) + common.RandRead(authInfo.WorldState.Rand, randByte) + authInfo.MockDomain = localConfig.MockDomainList[int(randByte[0])%len(localConfig.MockDomainList)] + // sessionID is usergenerated. There shouldn't be a security concern because the scope of // sessionID is limited to its UID. quad := make([]byte, 4) diff --git a/cmd/ck-client/protector_android.go b/cmd/ck-client/protector_android.go index 639b98c..fbaea7b 100644 --- a/cmd/ck-client/protector_android.go +++ b/cmd/ck-client/protector_android.go @@ -1,4 +1,5 @@ // +build android + package main // Stolen from https://github.com/shadowsocks/overture/blob/shadowsocks/core/utils/utils_android.go diff --git a/cmd/ck-server/ck-server_test.go b/cmd/ck-server/ck-server_test.go index 09de559..87427ee 100644 --- a/cmd/ck-server/ck-server_test.go +++ b/cmd/ck-server/ck-server_test.go @@ -9,49 +9,27 @@ import ( func TestParseBindAddr(t *testing.T) { t.Run("port only", func(t *testing.T) { addrs, err := resolveBindAddr([]string{":443"}) - if err != nil { - t.Error(err) - return - } - if addrs[0].String() != ":443" { - t.Errorf("expected %v got %v", ":443", addrs[0].String()) - } + assert.NoError(t, err) + assert.Equal(t, ":443", addrs[0].String()) }) t.Run("specific address", func(t *testing.T) { addrs, err := resolveBindAddr([]string{"192.168.1.123:443"}) - if err != nil { - t.Error(err) - return - } - if addrs[0].String() != "192.168.1.123:443" { - t.Errorf("expected %v got %v", "192.168.1.123:443", addrs[0].String()) - } + assert.NoError(t, err) + assert.Equal(t, "192.168.1.123:443", addrs[0].String()) }) t.Run("ipv6", func(t *testing.T) { addrs, err := resolveBindAddr([]string{"[::]:443"}) - if err != nil { - t.Error(err) - return - } - if addrs[0].String() != "[::]:443" { - t.Errorf("expected %v got %v", "[::]:443", addrs[0].String()) - } + assert.NoError(t, err) + assert.Equal(t, "[::]:443", addrs[0].String()) }) t.Run("mixed", func(t *testing.T) { addrs, err := resolveBindAddr([]string{":80", "[::]:443"}) - if err != nil { - t.Error(err) - return - } - if addrs[0].String() != ":80" { - t.Errorf("expected %v got %v", ":80", addrs[0].String()) - } - if addrs[1].String() != "[::]:443" { - t.Errorf("expected %v got %v", "[::]:443", addrs[1].String()) - } + assert.NoError(t, err) + assert.Equal(t, ":80", addrs[0].String()) + assert.Equal(t, "[::]:443", addrs[1].String()) }) } diff --git a/internal/client/TLS_test.go b/internal/client/TLS_test.go index 9093e8a..b8bdd81 100644 --- a/internal/client/TLS_test.go +++ b/internal/client/TLS_test.go @@ -1,8 +1,8 @@ package client import ( - "bytes" "encoding/hex" + "github.com/stretchr/testify/assert" "testing" ) @@ -33,11 +33,6 @@ func TestMakeServerName(t *testing.T) { } for _, p := range pairs { - if !bytes.Equal(makeServerName(p.serverName), p.target) { - t.Error( - "for", p.serverName, - "expecting", p.target, - "got", makeServerName(p.serverName)) - } + assert.Equal(t, p.target, makeServerName(p.serverName)) } } diff --git a/internal/client/auth_test.go b/internal/client/auth_test.go index eda0c2c..4c7da33 100644 --- a/internal/client/auth_test.go +++ b/internal/client/auth_test.go @@ -4,6 +4,7 @@ import ( "bytes" "github.com/cbeuw/Cloak/internal/common" "github.com/cbeuw/Cloak/internal/multiplex" + "github.com/stretchr/testify/assert" "testing" "time" ) @@ -64,12 +65,8 @@ func TestMakeAuthenticationPayload(t *testing.T) { for _, tc := range tests { func() { payload, sharedSecret := makeAuthenticationPayload(tc.authInfo) - if payload != tc.expPayload { - t.Errorf("payload doesn't match:\nexp %v\ngot %v", tc.expPayload, payload) - } - if sharedSecret != tc.expSecret { - t.Errorf("secret doesn't match:\nexp %x\ngot %x", tc.expPayload, payload) - } + assert.Equal(t, tc.expPayload, payload, "payload doesn't match") + assert.Equal(t, tc.expSecret, sharedSecret, "shared secret doesn't match") }() } } diff --git a/internal/client/state.go b/internal/client/state.go index 0ee914c..0ab8270 100644 --- a/internal/client/state.go +++ b/internal/client/state.go @@ -26,11 +26,11 @@ type RawConfig struct { UID []byte PublicKey []byte NumConn int - LocalHost string // jsonOptional - LocalPort string // jsonOptional - RemoteHost string // jsonOptional - RemotePort string // jsonOptional - + LocalHost string // jsonOptional + LocalPort string // jsonOptional + RemoteHost string // jsonOptional + RemotePort string // jsonOptional + AlternativeNames []string // jsonOptional // defaults set in ProcessRawConfig UDP bool // nullable BrowserSig string // nullable @@ -49,8 +49,9 @@ type RemoteConnConfig struct { } type LocalConnConfig struct { - LocalAddr string - Timeout time.Duration + LocalAddr string + Timeout time.Duration + MockDomainList []string } type AuthInfo struct { @@ -94,6 +95,20 @@ func ssvToJson(ssv string) (ret []byte) { } key := sp[0] value := sp[1] + if strings.HasPrefix(key, "AlternativeNames") { + switch strings.Contains(value, ",") { + case true: + domains := strings.Split(value, ",") + for index, domain := range domains { + domains[index] = `"` + domain + `"` + } + value = strings.Join(domains, ",") + ret = append(ret, []byte(`"`+key+`":[`+value+`],`)...) + case false: + ret = append(ret, []byte(`"`+key+`":["`+value+`"],`)...) + } + continue + } // JSON doesn't like quotation marks around int and bool // This is extremely ugly but it's still better than writing a tokeniser if elem(key, unquoted) { @@ -139,6 +154,8 @@ func (raw *RawConfig) ProcessRawConfig(worldState common.WorldState) (local Loca return nullErr("ServerName") } auth.MockDomain = raw.ServerName + local.MockDomainList = raw.AlternativeNames + local.MockDomainList = append(local.MockDomainList, auth.MockDomain) if raw.ProxyMethod == "" { return nullErr("ServerName") } diff --git a/internal/common/tls.go b/internal/common/tls.go index fb54e97..3953992 100644 --- a/internal/common/tls.go +++ b/internal/common/tls.go @@ -2,6 +2,7 @@ package common import ( "encoding/binary" + "errors" "io" "net" "sync" @@ -94,6 +95,9 @@ func (tls *TLSConn) Read(buffer []byte) (n int, err error) { func (tls *TLSConn) Write(in []byte) (n int, err error) { msgLen := len(in) + if msgLen > 1<<14+256 { // https://tools.ietf.org/html/rfc8446#section-5.2 + return 0, errors.New("message is too long") + } writeBuf := tls.writeBufPool.Get().(*[]byte) *writeBuf = append(*writeBuf, byte(msgLen>>8), byte(msgLen&0xFF)) *writeBuf = append(*writeBuf, in...) diff --git a/internal/multiplex/datagramBufferedPipe_test.go b/internal/multiplex/datagramBufferedPipe_test.go index 6b20f76..8d7a07b 100644 --- a/internal/multiplex/datagramBufferedPipe_test.go +++ b/internal/multiplex/datagramBufferedPipe_test.go @@ -1,7 +1,7 @@ package multiplex import ( - "bytes" + "github.com/stretchr/testify/assert" "testing" "time" ) @@ -11,13 +11,7 @@ func TestDatagramBuffer_RW(t *testing.T) { t.Run("simple write", func(t *testing.T) { pipe := NewDatagramBufferedPipe() _, err := pipe.Write(&Frame{Payload: b}) - if err != nil { - t.Error( - "expecting", "nil error", - "got", err, - ) - return - } + assert.NoError(t, err) }) t.Run("simple read", func(t *testing.T) { @@ -25,50 +19,18 @@ func TestDatagramBuffer_RW(t *testing.T) { _, _ = pipe.Write(&Frame{Payload: b}) b2 := make([]byte, len(b)) n, err := pipe.Read(b2) - if n != len(b) { - t.Error( - "For", "number of bytes read", - "expecting", len(b), - "got", n, - ) - return - } - if err != nil { - t.Error( - "expecting", "nil error", - "got", err, - ) - return - } - if !bytes.Equal(b, b2) { - t.Error( - "expecting", b, - "got", b2, - ) - } - if pipe.buf.Len() != 0 { - t.Error("buf len is not 0 after finished reading") - return - } - + assert.NoError(t, err) + assert.Equal(t, len(b), n) + assert.Equal(t, b, b2) + assert.Equal(t, 0, pipe.buf.Len(), "buf len is not 0 after finished reading") }) t.Run("writing closing frame", func(t *testing.T) { pipe := NewDatagramBufferedPipe() toBeClosed, err := pipe.Write(&Frame{Closing: closingStream}) - if !toBeClosed { - t.Error("should be to be closed") - } - if err != nil { - t.Error( - "expecting", "nil error", - "got", err, - ) - return - } - if !pipe.closed { - t.Error("expecting closed pipe, not closed") - } + assert.NoError(t, err) + assert.True(t, toBeClosed, "should be to be closed") + assert.True(t, pipe.closed, "pipe should be closed") }) } @@ -81,30 +43,9 @@ func TestDatagramBuffer_BlockingRead(t *testing.T) { }() b2 := make([]byte, len(b)) n, err := pipe.Read(b2) - if n != len(b) { - t.Error( - "For", "number of bytes read after block", - "expecting", len(b), - "got", n, - ) - return - } - if err != nil { - t.Error( - "For", "blocked read", - "expecting", "nil error", - "got", err, - ) - return - } - if !bytes.Equal(b, b2) { - t.Error( - "For", "blocked read", - "expecting", b, - "got", b2, - ) - return - } + assert.NoError(t, err) + assert.Equal(t, len(b), n, "number of bytes read after block is wrong") + assert.Equal(t, b, b2) } func TestDatagramBuffer_CloseThenRead(t *testing.T) { @@ -114,27 +55,7 @@ func TestDatagramBuffer_CloseThenRead(t *testing.T) { b2 := make([]byte, len(b)) pipe.Close() n, err := pipe.Read(b2) - if n != len(b) { - t.Error( - "For", "number of bytes read", - "expecting", len(b), - "got", n, - ) - } - if err != nil { - t.Error( - "For", "simple read", - "expecting", "nil error", - "got", err, - ) - return - } - if !bytes.Equal(b, b2) { - t.Error( - "For", "simple read", - "expecting", b, - "got", b2, - ) - return - } + assert.NoError(t, err) + assert.Equal(t, len(b), n, "number of bytes read after block is wrong") + assert.Equal(t, b, b2) } diff --git a/internal/multiplex/mux_test.go b/internal/multiplex/mux_test.go index e492305..6bb57ed 100644 --- a/internal/multiplex/mux_test.go +++ b/internal/multiplex/mux_test.go @@ -108,9 +108,7 @@ func TestMultiplex(t *testing.T) { streams := make([]net.Conn, numStreams) for i := 0; i < numStreams; i++ { stream, err := clientSession.OpenStream() - if err != nil { - t.Fatalf("failed to open stream: %v", err) - } + assert.NoError(t, err) streams[i] = stream } @@ -123,18 +121,11 @@ func TestMultiplex(t *testing.T) { // close one stream closing, streams := streams[0], streams[1:] err := closing.Close() - if err != nil { - t.Errorf("couldn't close a stream") - } + assert.NoError(t, err, "couldn't close a stream") _, err = closing.Write([]byte{0}) - if err != ErrBrokenStream { - t.Errorf("expecting error %v, got %v", ErrBrokenStream, err) - } + assert.Equal(t, ErrBrokenStream, err) _, err = closing.Read(make([]byte, 1)) - if err != ErrBrokenStream { - t.Errorf("expecting error %v, got %v", ErrBrokenStream, err) - } - + assert.Equal(t, ErrBrokenStream, err) } func TestMux_StreamClosing(t *testing.T) { @@ -146,20 +137,13 @@ func TestMux_StreamClosing(t *testing.T) { recvBuf := make([]byte, 128) toBeClosed, _ := clientSession.OpenStream() _, err := toBeClosed.Write(testData) // should be echoed back - if err != nil { - t.Errorf("can't write to stream: %v", err) - } + assert.NoError(t, err, "couldn't write to a stream") _, err = io.ReadFull(toBeClosed, recvBuf[:1]) - if err != nil { - t.Errorf("can't read anything before stream closed: %v", err) - } + assert.NoError(t, err, "can't read anything before stream closed") + _ = toBeClosed.Close() _, err = io.ReadFull(toBeClosed, recvBuf[1:]) - if err != nil { - t.Errorf("can't read residual data on stream: %v", err) - } - if !bytes.Equal(testData, recvBuf) { - t.Errorf("incorrect data read back") - } + assert.NoError(t, err, "can't read residual data on stream") + assert.Equal(t, testData, recvBuf, "incorrect data read back") } diff --git a/internal/multiplex/obfs_test.go b/internal/multiplex/obfs_test.go index 2dad728..76ba3bc 100644 --- a/internal/multiplex/obfs_test.go +++ b/internal/multiplex/obfs_test.go @@ -138,7 +138,7 @@ func BenchmarkObfs(b *testing.B) { testPayload, } - obfsBuf := make([]byte, defaultSendRecvBufSize) + obfsBuf := make([]byte, len(testPayload)*2) var key [32]byte rand.Read(key[:]) @@ -211,7 +211,7 @@ func BenchmarkDeobfs(b *testing.B) { testPayload, } - obfsBuf := make([]byte, defaultSendRecvBufSize) + obfsBuf := make([]byte, len(testPayload)*2) var key [32]byte rand.Read(key[:]) diff --git a/internal/multiplex/recvBuffer.go b/internal/multiplex/recvBuffer.go index 63f1f6f..91af149 100644 --- a/internal/multiplex/recvBuffer.go +++ b/internal/multiplex/recvBuffer.go @@ -25,4 +25,4 @@ type recvBuffer interface { // size we want the amount of unread data in buffer to grow before recvBuffer.Write blocks. // If the buffer grows larger than what the system's memory can offer at the time of recvBuffer.Write, // a panic will happen. -const recvBufferSizeLimit = defaultSendRecvBufSize << 12 +const recvBufferSizeLimit = 1<<31 - 1 diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index 6abc90e..f530f19 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -13,10 +13,9 @@ import ( ) const ( - acceptBacklog = 1024 - // TODO: will this be a signature? - defaultSendRecvBufSize = 20480 + acceptBacklog = 1024 defaultInactivityTimeout = 30 * time.Second + defaultMaxOnWireSize = 1<<14 + 256 // https://tools.ietf.org/html/rfc8446#section-5.2 ) var ErrBrokenSession = errors.New("broken session") @@ -24,8 +23,6 @@ var errRepeatSessionClosing = errors.New("trying to close a closed session") var errRepeatStreamClosing = errors.New("trying to close a closed stream") var errNoMultiplex = errors.New("a singleplexing session can have only one stream") -type switchboardStrategy int - type SessionConfig struct { Obfuscator @@ -40,12 +37,6 @@ type SessionConfig struct { // maximum size of an obfuscated frame, including headers and overhead MsgOnWireSizeLimit int - // StreamSendBufferSize sets the buffer size used to send data from a Stream (Stream.obfsBuf) - StreamSendBufferSize int - // ConnReceiveBufferSize sets the buffer size used to receive data from an underlying Conn (allocated in - // switchboard.deplex) - ConnReceiveBufferSize int - // InactivityTimeout sets the duration a Session waits while it has no active streams before it closes itself InactivityTimeout time.Duration } @@ -82,11 +73,17 @@ type Session struct { closed uint32 - terminalMsg atomic.Value + terminalMsgSetter sync.Once + terminalMsg string // the max size passed to Write calls before it splits it into multiple frames // i.e. the max size a piece of data can fit into a Frame.Payload maxStreamUnitWrite int + // streamSendBufferSize sets the buffer size used to send data from a Stream (Stream.obfsBuf) + streamSendBufferSize int + // connReceiveBufferSize sets the buffer size used to receive data from an underlying Conn (allocated in + // switchboard.deplex) + connReceiveBufferSize int } func MakeSession(id uint32, config SessionConfig) *Session { @@ -103,23 +100,19 @@ func MakeSession(id uint32, config SessionConfig) *Session { if config.Valve == nil { sesh.Valve = UNLIMITED_VALVE } - if config.StreamSendBufferSize <= 0 { - sesh.StreamSendBufferSize = defaultSendRecvBufSize - } - if config.ConnReceiveBufferSize <= 0 { - sesh.ConnReceiveBufferSize = defaultSendRecvBufSize - } if config.MsgOnWireSizeLimit <= 0 { - sesh.MsgOnWireSizeLimit = defaultSendRecvBufSize - 1024 + sesh.MsgOnWireSizeLimit = defaultMaxOnWireSize } if config.InactivityTimeout == 0 { sesh.InactivityTimeout = defaultInactivityTimeout } - // todo: validation. this must be smaller than StreamSendBufferSize - sesh.maxStreamUnitWrite = sesh.MsgOnWireSizeLimit - frameHeaderLength - sesh.Obfuscator.maxOverhead + + sesh.maxStreamUnitWrite = sesh.MsgOnWireSizeLimit - frameHeaderLength - sesh.maxOverhead + sesh.streamSendBufferSize = sesh.MsgOnWireSizeLimit + sesh.connReceiveBufferSize = 20480 // for backwards compatibility sesh.streamObfsBufPool = sync.Pool{New: func() interface{} { - b := make([]byte, sesh.StreamSendBufferSize) + b := make([]byte, sesh.streamSendBufferSize) return &b }} @@ -187,7 +180,7 @@ func (sesh *Session) closeStream(s *Stream, active bool) error { if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) { return fmt.Errorf("closing stream %v: %w", s.id, errRepeatStreamClosing) } - _ = s.getRecvBuf().Close() // recvBuf.Close should not return error + _ = s.recvBuf.Close() // recvBuf.Close should not return error if active { tmpBuf := sesh.streamObfsBufPool.Get().(*[]byte) @@ -271,16 +264,13 @@ func (sesh *Session) recvDataFromRemote(data []byte) error { } func (sesh *Session) SetTerminalMsg(msg string) { - sesh.terminalMsg.Store(msg) + sesh.terminalMsgSetter.Do(func() { + sesh.terminalMsg = msg + }) } func (sesh *Session) TerminalMsg() string { - msg := sesh.terminalMsg.Load() - if msg != nil { - return msg.(string) - } else { - return "" - } + return sesh.terminalMsg } func (sesh *Session) closeSession() error { @@ -292,13 +282,11 @@ func (sesh *Session) closeSession() error { sesh.streamsM.Lock() close(sesh.acceptCh) for id, stream := range sesh.streams { - if stream == nil { - continue + if stream != nil && atomic.CompareAndSwapUint32(&stream.closed, 0, 1) { + _ = stream.recvBuf.Close() // will not block + delete(sesh.streams, id) + sesh.streamCountDecr() } - atomic.StoreUint32(&stream.closed, 1) - _ = stream.getRecvBuf().Close() // will not block - delete(sesh.streams, id) - sesh.streamCountDecr() } sesh.streamsM.Unlock() return nil @@ -339,7 +327,7 @@ func (sesh *Session) Close() error { if err != nil { return err } - _, err = sesh.sb.send((*buf)[:i], new(uint32)) + _, err = sesh.sb.send((*buf)[:i], new(net.Conn)) if err != nil { return err } diff --git a/internal/multiplex/session_test.go b/internal/multiplex/session_test.go index dfa3dbb..88572a6 100644 --- a/internal/multiplex/session_test.go +++ b/internal/multiplex/session_test.go @@ -534,7 +534,7 @@ func TestSession_timeoutAfter(t *testing.T) { func BenchmarkRecvDataFromRemote(b *testing.B) { testPayload := make([]byte, testPayloadLen) rand.Read(testPayload) - f := &Frame{ + f := Frame{ 1, 0, 0, @@ -544,12 +544,13 @@ func BenchmarkRecvDataFromRemote(b *testing.B) { var sessionKey [32]byte rand.Read(sessionKey[:]) - const maxIter = 100_000 // run with -benchtime 100000x to avoid index out of bounds panic + const maxIter = 500_000 // run with -benchtime 500000x to avoid index out of bounds panic for name, ep := range encryptionMethods { ep := ep b.Run(name, func(b *testing.B) { for seshType, seshConfig := range seshConfigs { b.Run(seshType, func(b *testing.B) { + f := f seshConfig.Obfuscator, _ = MakeObfuscator(ep, sessionKey) sesh := MakeSession(0, seshConfig) @@ -561,7 +562,7 @@ func BenchmarkRecvDataFromRemote(b *testing.B) { binaryFrames := [maxIter][]byte{} for i := 0; i < maxIter; i++ { obfsBuf := make([]byte, obfsBufLen) - n, _ := sesh.obfuscate(f, obfsBuf, 0) + n, _ := sesh.obfuscate(&f, obfsBuf, 0) binaryFrames[i] = obfsBuf[:n] f.Seq++ } diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index ffd7e23..8c2ea15 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -23,9 +23,8 @@ type Stream struct { session *Session - allocIdempot sync.Once // a buffer (implemented as an asynchronous buffered pipe) to put data we've received from recvFrame but hasn't - // been read by the consumer through Read or WriteTo. Lazily allocated + // been read by the consumer through Read or WriteTo. recvBuf recvBuffer writingM sync.Mutex @@ -40,7 +39,7 @@ type Stream struct { // recvBuffer (implemented by streamBuffer under ordered mode) will not receive out-of-order packets // so it won't have to use its priority queue to sort it. // This is not used in unordered connection mode - assignedConnId uint32 + assignedConn net.Conn readFromTimeout time.Duration } @@ -56,25 +55,20 @@ func makeStream(sesh *Session, id uint32) *Stream { }, } + if sesh.Unordered { + stream.recvBuf = NewDatagramBufferedPipe() + } else { + stream.recvBuf = NewStreamBuffer() + } + return stream } func (s *Stream) isClosed() bool { return atomic.LoadUint32(&s.closed) == 1 } -func (s *Stream) getRecvBuf() recvBuffer { - s.allocIdempot.Do(func() { - if s.session.Unordered { - s.recvBuf = NewDatagramBufferedPipe() - } else { - s.recvBuf = NewStreamBuffer() - } - }) - return s.recvBuf -} - // receive a readily deobfuscated Frame so its payload can later be Read func (s *Stream) recvFrame(frame *Frame) error { - toBeClosed, err := s.getRecvBuf().Write(frame) + toBeClosed, err := s.recvBuf.Write(frame) if toBeClosed { err = s.passiveClose() if errors.Is(err, errRepeatStreamClosing) { @@ -93,7 +87,7 @@ func (s *Stream) Read(buf []byte) (n int, err error) { return 0, nil } - n, err = s.getRecvBuf().Read(buf) + n, err = s.recvBuf.Read(buf) log.Tracef("%v read from stream %v with err %v", n, s.id, err) if err == io.EOF { return n, ErrBrokenStream @@ -104,7 +98,7 @@ func (s *Stream) Read(buf []byte) (n int, err error) { // WriteTo continuously write data Stream has received into the writer w. func (s *Stream) WriteTo(w io.Writer) (int64, error) { // will keep writing until the underlying buffer is closed - n, err := s.getRecvBuf().WriteTo(w) + n, err := s.recvBuf.WriteTo(w) log.Tracef("%v read from stream %v with err %v", n, s.id, err) if err == io.EOF { return n, ErrBrokenStream @@ -119,7 +113,7 @@ func (s *Stream) obfuscateAndSend(buf []byte, payloadOffsetInBuf int) error { return err } - _, err = s.session.sb.send(buf[:cipherTextLen], &s.assignedConnId) + _, err = s.session.sb.send(buf[:cipherTextLen], &s.assignedConn) if err != nil { if err == errBrokenSwitchboard { s.session.SetTerminalMsg(err.Error()) @@ -215,8 +209,8 @@ func (s *Stream) Close() error { func (s *Stream) LocalAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[0] } func (s *Stream) RemoteAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[1] } -func (s *Stream) SetWriteToTimeout(d time.Duration) { s.getRecvBuf().SetWriteToTimeout(d) } -func (s *Stream) SetReadDeadline(t time.Time) error { s.getRecvBuf().SetReadDeadline(t); return nil } +func (s *Stream) SetWriteToTimeout(d time.Duration) { s.recvBuf.SetWriteToTimeout(d) } +func (s *Stream) SetReadDeadline(t time.Time) error { s.recvBuf.SetReadDeadline(t); return nil } func (s *Stream) SetReadFromTimeout(d time.Duration) { s.readFromTimeout = d } var errNotImplemented = errors.New("Not implemented") diff --git a/internal/multiplex/streamBufferedPipe_test.go b/internal/multiplex/streamBufferedPipe_test.go index ff0ec24..6cdbff1 100644 --- a/internal/multiplex/streamBufferedPipe_test.go +++ b/internal/multiplex/streamBufferedPipe_test.go @@ -1,7 +1,7 @@ package multiplex import ( - "bytes" + "github.com/stretchr/testify/assert" "math/rand" "testing" "time" @@ -13,49 +13,15 @@ func TestPipeRW(t *testing.T) { pipe := NewStreamBufferedPipe() b := []byte{0x01, 0x02, 0x03} n, err := pipe.Write(b) - if n != len(b) { - t.Error( - "For", "number of bytes written", - "expecting", len(b), - "got", n, - ) - return - } - if err != nil { - t.Error( - "For", "simple write", - "expecting", "nil error", - "got", err, - ) - return - } + assert.NoError(t, err, "simple write") + assert.Equal(t, len(b), n, "number of bytes written") b2 := make([]byte, len(b)) n, err = pipe.Read(b2) - if n != len(b) { - t.Error( - "For", "number of bytes read", - "expecting", len(b), - "got", n, - ) - return - } - if err != nil { - t.Error( - "For", "simple read", - "expecting", "nil error", - "got", err, - ) - return - } - if !bytes.Equal(b, b2) { - t.Error( - "For", "simple read", - "expecting", b, - "got", b2, - ) - } + assert.NoError(t, err, "simple read") + assert.Equal(t, len(b), n, "number of bytes read") + assert.Equal(t, b, b2) } func TestReadBlock(t *testing.T) { @@ -67,30 +33,10 @@ func TestReadBlock(t *testing.T) { }() b2 := make([]byte, len(b)) n, err := pipe.Read(b2) - if n != len(b) { - t.Error( - "For", "number of bytes read after block", - "expecting", len(b), - "got", n, - ) - return - } - if err != nil { - t.Error( - "For", "blocked read", - "expecting", "nil error", - "got", err, - ) - return - } - if !bytes.Equal(b, b2) { - t.Error( - "For", "blocked read", - "expecting", b, - "got", b2, - ) - return - } + assert.NoError(t, err, "blocked read") + assert.Equal(t, len(b), n, "number of bytes read after block") + + assert.Equal(t, b, b2) } func TestPartialRead(t *testing.T) { @@ -99,54 +45,17 @@ func TestPartialRead(t *testing.T) { pipe.Write(b) b1 := make([]byte, 1) n, err := pipe.Read(b1) - if n != len(b1) { - t.Error( - "For", "number of bytes in partial read of 1", - "expecting", len(b1), - "got", n, - ) - return - } - if err != nil { - t.Error( - "For", "partial read of 1", - "expecting", "nil error", - "got", err, - ) - return - } - if b1[0] != b[0] { - t.Error( - "For", "partial read of 1", - "expecting", b[0], - "got", b1[0], - ) - } + assert.NoError(t, err, "partial read of 1") + assert.Equal(t, len(b1), n, "number of bytes in partial read of 1") + + assert.Equal(t, b[0], b1[0]) + b2 := make([]byte, 2) n, err = pipe.Read(b2) - if n != len(b2) { - t.Error( - "For", "number of bytes in partial read of 2", - "expecting", len(b2), - "got", n, - ) - } - if err != nil { - t.Error( - "For", "partial read of 2", - "expecting", "nil error", - "got", err, - ) - return - } - if !bytes.Equal(b[1:], b2) { - t.Error( - "For", "partial read of 2", - "expecting", b[1:], - "got", b2, - ) - return - } + assert.NoError(t, err, "partial read of 2") + assert.Equal(t, len(b2), n, "number of bytes in partial read of 2") + + assert.Equal(t, b[1:], b2) } func TestReadAfterClose(t *testing.T) { @@ -156,29 +65,10 @@ func TestReadAfterClose(t *testing.T) { b2 := make([]byte, len(b)) pipe.Close() n, err := pipe.Read(b2) - if n != len(b) { - t.Error( - "For", "number of bytes read", - "expecting", len(b), - "got", n, - ) - } - if err != nil { - t.Error( - "For", "simple read", - "expecting", "nil error", - "got", err, - ) - return - } - if !bytes.Equal(b, b2) { - t.Error( - "For", "simple read", - "expecting", b, - "got", b2, - ) - return - } + assert.NoError(t, err, "simple read") + assert.Equal(t, len(b), n, "number of bytes read") + + assert.Equal(t, b, b2) } func BenchmarkBufferedPipe_RW(b *testing.B) { diff --git a/internal/multiplex/switchboard.go b/internal/multiplex/switchboard.go index 84e43c9..fa35567 100644 --- a/internal/multiplex/switchboard.go +++ b/internal/multiplex/switchboard.go @@ -10,9 +10,11 @@ import ( "time" ) +type switchboardStrategy int + const ( - FIXED_CONN_MAPPING switchboardStrategy = iota - UNIFORM_SPREAD + fixedConnMapping switchboardStrategy = iota + uniformSpread ) // switchboard represents the connection pool. It is responsible for managing @@ -28,10 +30,8 @@ type switchboard struct { valve Valve strategy switchboardStrategy - // map of connId to net.Conn conns sync.Map - numConns uint32 - nextConnId uint32 + connsCount uint32 randPool sync.Pool broken uint32 @@ -41,15 +41,14 @@ func makeSwitchboard(sesh *Session) *switchboard { var strategy switchboardStrategy if sesh.Unordered { log.Debug("Connection is unordered") - strategy = UNIFORM_SPREAD + strategy = uniformSpread } else { - strategy = FIXED_CONN_MAPPING + strategy = fixedConnMapping } sb := &switchboard{ - session: sesh, - strategy: strategy, - valve: sesh.Valve, - nextConnId: 1, + session: sesh, + strategy: strategy, + valve: sesh.Valve, randPool: sync.Pool{New: func() interface{} { return rand.New(rand.NewSource(int64(time.Now().Nanosecond()))) }}, @@ -59,88 +58,85 @@ func makeSwitchboard(sesh *Session) *switchboard { var errBrokenSwitchboard = errors.New("the switchboard is broken") -func (sb *switchboard) connsCount() int { - return int(atomic.LoadUint32(&sb.numConns)) -} - func (sb *switchboard) addConn(conn net.Conn) { - connId := atomic.AddUint32(&sb.nextConnId, 1) - 1 - atomic.AddUint32(&sb.numConns, 1) - sb.conns.Store(connId, conn) - go sb.deplex(connId, conn) + atomic.AddUint32(&sb.connsCount, 1) + sb.conns.Store(conn, conn) + go sb.deplex(conn) } -// a pointer to connId is passed here so that the switchboard can reassign it if that connId isn't usable -func (sb *switchboard) send(data []byte, connId *uint32) (n int, err error) { +// a pointer to assignedConn is passed here so that the switchboard can reassign it if that conn isn't usable +func (sb *switchboard) send(data []byte, assignedConn *net.Conn) (n int, err error) { sb.valve.txWait(len(data)) - if atomic.LoadUint32(&sb.broken) == 1 || sb.connsCount() == 0 { + if atomic.LoadUint32(&sb.broken) == 1 { return 0, errBrokenSwitchboard } var conn net.Conn switch sb.strategy { - case UNIFORM_SPREAD: - _, conn, err = sb.pickRandConn() + case uniformSpread: + conn, err = sb.pickRandConn() if err != nil { return 0, errBrokenSwitchboard } - case FIXED_CONN_MAPPING: - connI, ok := sb.conns.Load(*connId) - if ok { - conn = connI.(net.Conn) - } else { - var newConnId uint32 - newConnId, conn, err = sb.pickRandConn() + n, err = conn.Write(data) + if err != nil { + sb.session.SetTerminalMsg("failed to send to remote " + err.Error()) + sb.session.passiveClose() + return n, err + } + case fixedConnMapping: + conn = *assignedConn + if conn == nil { + conn, err = sb.pickRandConn() if err != nil { - return 0, errBrokenSwitchboard + sb.session.SetTerminalMsg("failed to pick a connection " + err.Error()) + sb.session.passiveClose() + return 0, err } - *connId = newConnId + *assignedConn = conn + } + n, err = conn.Write(data) + if err != nil { + sb.session.SetTerminalMsg("failed to send to remote " + err.Error()) + sb.session.passiveClose() + return n, err } default: return 0, errors.New("unsupported traffic distribution strategy") } - n, err = conn.Write(data) - if err != nil { - sb.conns.Delete(*connId) - sb.session.SetTerminalMsg("failed to write to remote " + err.Error()) - sb.session.passiveClose() - return n, err - } sb.valve.AddTx(int64(n)) return n, nil } // returns a random connId -func (sb *switchboard) pickRandConn() (uint32, net.Conn, error) { - connCount := sb.connsCount() - if atomic.LoadUint32(&sb.broken) == 1 || connCount == 0 { - return 0, nil, errBrokenSwitchboard +func (sb *switchboard) pickRandConn() (net.Conn, error) { + if atomic.LoadUint32(&sb.broken) == 1 { + return nil, errBrokenSwitchboard + } + + connsCount := atomic.LoadUint32(&sb.connsCount) + if connsCount == 0 { + return nil, errBrokenSwitchboard } - // there is no guarantee that sb.conns still has the same amount of entries - // between the count loop and the pick loop - // so if the r > len(sb.conns) at the point of range call, the last visited element is picked - var id uint32 - var conn net.Conn randReader := sb.randPool.Get().(*rand.Rand) - r := randReader.Intn(connCount) + + r := randReader.Intn(int(connsCount)) sb.randPool.Put(randReader) + var c int - sb.conns.Range(func(connIdI, connI interface{}) bool { + var ret net.Conn + sb.conns.Range(func(_, conn interface{}) bool { if r == c { - id = connIdI.(uint32) - conn = connI.(net.Conn) + ret = conn.(net.Conn) return false } c++ return true }) - // if len(sb.conns) is 0 - if conn == nil { - return 0, nil, errBrokenSwitchboard - } - return id, conn, nil + + return ret, nil } // actively triggered by session.Close() @@ -148,26 +144,24 @@ func (sb *switchboard) closeAll() { if !atomic.CompareAndSwapUint32(&sb.broken, 0, 1) { return } - sb.conns.Range(func(key, connI interface{}) bool { - conn := connI.(net.Conn) - conn.Close() - sb.conns.Delete(key) + sb.conns.Range(func(_, conn interface{}) bool { + conn.(net.Conn).Close() + sb.conns.Delete(conn) + atomic.AddUint32(&sb.connsCount, ^uint32(0)) return true }) } // deplex function costantly reads from a TCP connection -func (sb *switchboard) deplex(connId uint32, conn net.Conn) { +func (sb *switchboard) deplex(conn net.Conn) { defer conn.Close() - buf := make([]byte, sb.session.ConnReceiveBufferSize) + buf := make([]byte, sb.session.connReceiveBufferSize) for { n, err := conn.Read(buf) sb.valve.rxWait(n) sb.valve.AddRx(int64(n)) if err != nil { log.Debugf("a connection for session %v has closed: %v", sb.session.id, err) - sb.conns.Delete(connId) - atomic.AddUint32(&sb.numConns, ^uint32(0)) sb.session.SetTerminalMsg("a connection has dropped unexpectedly") sb.session.passiveClose() return diff --git a/internal/multiplex/switchboard_test.go b/internal/multiplex/switchboard_test.go index 2c3f36f..5e95afe 100644 --- a/internal/multiplex/switchboard_test.go +++ b/internal/multiplex/switchboard_test.go @@ -5,6 +5,7 @@ import ( "github.com/stretchr/testify/assert" "math/rand" "sync" + "sync/atomic" "testing" "time" ) @@ -14,14 +15,14 @@ func TestSwitchboard_Send(t *testing.T) { sesh := MakeSession(0, seshConfig) hole0 := connutil.Discard() sesh.sb.addConn(hole0) - connId, _, err := sesh.sb.pickRandConn() + conn, err := sesh.sb.pickRandConn() if err != nil { t.Error("failed to get a random conn", err) return } data := make([]byte, 1000) rand.Read(data) - _, err = sesh.sb.send(data, &connId) + _, err = sesh.sb.send(data, &conn) if err != nil { t.Error(err) return @@ -29,23 +30,23 @@ func TestSwitchboard_Send(t *testing.T) { hole1 := connutil.Discard() sesh.sb.addConn(hole1) - connId, _, err = sesh.sb.pickRandConn() + conn, err = sesh.sb.pickRandConn() if err != nil { t.Error("failed to get a random conn", err) return } - _, err = sesh.sb.send(data, &connId) + _, err = sesh.sb.send(data, &conn) if err != nil { t.Error(err) return } - connId, _, err = sesh.sb.pickRandConn() + conn, err = sesh.sb.pickRandConn() if err != nil { t.Error("failed to get a random conn", err) return } - _, err = sesh.sb.send(data, &connId) + _, err = sesh.sb.send(data, &conn) if err != nil { t.Error(err) return @@ -71,7 +72,7 @@ func BenchmarkSwitchboard_Send(b *testing.B) { seshConfig := SessionConfig{} sesh := MakeSession(0, seshConfig) sesh.sb.addConn(hole) - connId, _, err := sesh.sb.pickRandConn() + conn, err := sesh.sb.pickRandConn() if err != nil { b.Error("failed to get a random conn", err) return @@ -81,7 +82,7 @@ func BenchmarkSwitchboard_Send(b *testing.B) { b.SetBytes(int64(len(data))) b.ResetTimer() for i := 0; i < b.N; i++ { - sesh.sb.send(data, &connId) + sesh.sb.send(data, &conn) } } @@ -92,7 +93,7 @@ func TestSwitchboard_TxCredit(t *testing.T) { sesh := MakeSession(0, seshConfig) hole := connutil.Discard() sesh.sb.addConn(hole) - connId, _, err := sesh.sb.pickRandConn() + conn, err := sesh.sb.pickRandConn() if err != nil { t.Error("failed to get a random conn", err) return @@ -100,10 +101,10 @@ func TestSwitchboard_TxCredit(t *testing.T) { data := make([]byte, 1000) rand.Read(data) - t.Run("FIXED CONN MAPPING", func(t *testing.T) { + t.Run("fixed conn mapping", func(t *testing.T) { *sesh.sb.valve.(*LimitedValve).tx = 0 - sesh.sb.strategy = FIXED_CONN_MAPPING - n, err := sesh.sb.send(data[:10], &connId) + sesh.sb.strategy = fixedConnMapping + n, err := sesh.sb.send(data[:10], &conn) if err != nil { t.Error(err) return @@ -116,10 +117,10 @@ func TestSwitchboard_TxCredit(t *testing.T) { t.Error("tx credit didn't increase by 10") } }) - t.Run("UNIFORM", func(t *testing.T) { + t.Run("uniform spread", func(t *testing.T) { *sesh.sb.valve.(*LimitedValve).tx = 0 - sesh.sb.strategy = UNIFORM_SPREAD - n, err := sesh.sb.send(data[:10], &connId) + sesh.sb.strategy = uniformSpread + n, err := sesh.sb.send(data[:10], &conn) if err != nil { t.Error(err) return @@ -173,13 +174,13 @@ func TestSwitchboard_ConnsCount(t *testing.T) { } wg.Wait() - if sesh.sb.connsCount() != 1000 { + if atomic.LoadUint32(&sesh.sb.connsCount) != 1000 { t.Error("connsCount incorrect") } sesh.sb.closeAll() assert.Eventuallyf(t, func() bool { - return sesh.sb.connsCount() == 0 - }, time.Second, 10*time.Millisecond, "connsCount incorrect: %v", sesh.sb.connsCount()) + return atomic.LoadUint32(&sesh.sb.connsCount) == 0 + }, time.Second, 10*time.Millisecond, "connsCount incorrect: %v", atomic.LoadUint32(&sesh.sb.connsCount)) } diff --git a/internal/server/usermanager/api.yaml b/internal/server/usermanager/api.yaml index 39ec6d7..d49875d 100644 --- a/internal/server/usermanager/api.yaml +++ b/internal/server/usermanager/api.yaml @@ -2,7 +2,7 @@ swagger: '2.0' info: description: | This is the API of Cloak server - version: 1.0.0 + version: 0.0.2 title: Cloak Server contact: email: cbeuw.andy@gmail.com @@ -12,8 +12,6 @@ info: # host: petstore.swagger.io # basePath: /v2 tags: - - name: admin - description: Endpoints used by the host administrators - name: users description: Operations related to user controls by admin # schemes: @@ -22,7 +20,6 @@ paths: /admin/users: get: tags: - - admin - users summary: Show all users description: Returns an array of all UserInfo @@ -41,7 +38,6 @@ paths: /admin/users/{UID}: get: tags: - - admin - users summary: Show userinfo by UID description: Returns a UserInfo object @@ -68,7 +64,6 @@ paths: description: internal error post: tags: - - admin - users summary: Updates the userinfo of the specified user, if the user does not exist, then a new user is created operationId: writeUserInfo @@ -100,7 +95,6 @@ paths: description: internal error delete: tags: - - admin - users summary: Deletes a user operationId: deleteUser diff --git a/internal/server/usermanager/api_router_test.go b/internal/server/usermanager/api_router_test.go index 1310f60..9f957c8 100644 --- a/internal/server/usermanager/api_router_test.go +++ b/internal/server/usermanager/api_router_test.go @@ -46,6 +46,36 @@ func TestWriteUserInfoHlr(t *testing.T) { assert.Equalf(t, http.StatusCreated, rr.Code, "response body: %v", rr.Body) }) + t.Run("partial update", func(t *testing.T) { + req, err := http.NewRequest("POST", "/admin/users/"+mockUIDb64, bytes.NewBuffer(marshalled)) + assert.NoError(t, err) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + assert.Equal(t, http.StatusCreated, rr.Code) + + partialUserInfo := UserInfo{ + UID: mockUID, + SessionsCap: JustInt32(10), + } + partialMarshalled, _ := json.Marshal(partialUserInfo) + req, err = http.NewRequest("POST", "/admin/users/"+mockUIDb64, bytes.NewBuffer(partialMarshalled)) + assert.NoError(t, err) + router.ServeHTTP(rr, req) + assert.Equal(t, http.StatusCreated, rr.Code) + + req, err = http.NewRequest("GET", "/admin/users/"+mockUIDb64, nil) + assert.NoError(t, err) + router.ServeHTTP(rr, req) + assert.Equal(t, http.StatusCreated, rr.Code) + var got UserInfo + err = json.Unmarshal(rr.Body.Bytes(), &got) + assert.NoError(t, err) + + expected := mockUserInfo + expected.SessionsCap = partialUserInfo.SessionsCap + assert.EqualValues(t, expected, got) + }) + t.Run("empty parameter", func(t *testing.T) { req, err := http.NewRequest("POST", "/admin/users/", bytes.NewBuffer(marshalled)) if err != nil { diff --git a/internal/server/usermanager/localmanager.go b/internal/server/usermanager/localmanager.go index 1689595..d60f62e 100644 --- a/internal/server/usermanager/localmanager.go +++ b/internal/server/usermanager/localmanager.go @@ -127,6 +127,7 @@ func (manager *localManager) UploadStatus(uploads []StatusUpdate) ([]StatusRespo "User no longer exists", } responses = append(responses, resp) + continue } oldUp := int64(u64(bucket.Get([]byte("UpCredit")))) @@ -179,12 +180,12 @@ func (manager *localManager) ListAllUsers() (infos []UserInfo, err error) { err = tx.ForEach(func(UID []byte, bucket *bolt.Bucket) error { var uinfo UserInfo uinfo.UID = UID - uinfo.SessionsCap = int32(u32(bucket.Get([]byte("SessionsCap")))) - uinfo.UpRate = int64(u64(bucket.Get([]byte("UpRate")))) - uinfo.DownRate = int64(u64(bucket.Get([]byte("DownRate")))) - uinfo.UpCredit = int64(u64(bucket.Get([]byte("UpCredit")))) - uinfo.DownCredit = int64(u64(bucket.Get([]byte("DownCredit")))) - uinfo.ExpiryTime = int64(u64(bucket.Get([]byte("ExpiryTime")))) + uinfo.SessionsCap = JustInt32(int32(u32(bucket.Get([]byte("SessionsCap"))))) + uinfo.UpRate = JustInt64(int64(u64(bucket.Get([]byte("UpRate"))))) + uinfo.DownRate = JustInt64(int64(u64(bucket.Get([]byte("DownRate"))))) + uinfo.UpCredit = JustInt64(int64(u64(bucket.Get([]byte("UpCredit"))))) + uinfo.DownCredit = JustInt64(int64(u64(bucket.Get([]byte("DownCredit"))))) + uinfo.ExpiryTime = JustInt64(int64(u64(bucket.Get([]byte("ExpiryTime"))))) infos = append(infos, uinfo) return nil }) @@ -200,40 +201,52 @@ func (manager *localManager) GetUserInfo(UID []byte) (uinfo UserInfo, err error) return ErrUserNotFound } uinfo.UID = UID - uinfo.SessionsCap = int32(u32(bucket.Get([]byte("SessionsCap")))) - uinfo.UpRate = int64(u64(bucket.Get([]byte("UpRate")))) - uinfo.DownRate = int64(u64(bucket.Get([]byte("DownRate")))) - uinfo.UpCredit = int64(u64(bucket.Get([]byte("UpCredit")))) - uinfo.DownCredit = int64(u64(bucket.Get([]byte("DownCredit")))) - uinfo.ExpiryTime = int64(u64(bucket.Get([]byte("ExpiryTime")))) + uinfo.SessionsCap = JustInt32(int32(u32(bucket.Get([]byte("SessionsCap"))))) + uinfo.UpRate = JustInt64(int64(u64(bucket.Get([]byte("UpRate"))))) + uinfo.DownRate = JustInt64(int64(u64(bucket.Get([]byte("DownRate"))))) + uinfo.UpCredit = JustInt64(int64(u64(bucket.Get([]byte("UpCredit"))))) + uinfo.DownCredit = JustInt64(int64(u64(bucket.Get([]byte("DownCredit"))))) + uinfo.ExpiryTime = JustInt64(int64(u64(bucket.Get([]byte("ExpiryTime"))))) return nil }) return } -func (manager *localManager) WriteUserInfo(uinfo UserInfo) (err error) { +func (manager *localManager) WriteUserInfo(u UserInfo) (err error) { err = manager.db.Update(func(tx *bolt.Tx) error { - bucket, err := tx.CreateBucketIfNotExists(uinfo.UID) + bucket, err := tx.CreateBucketIfNotExists(u.UID) if err != nil { return err } - if err = bucket.Put([]byte("SessionsCap"), i32ToB(int32(uinfo.SessionsCap))); err != nil { - return err + if u.SessionsCap != nil { + if err = bucket.Put([]byte("SessionsCap"), i32ToB(*u.SessionsCap)); err != nil { + return err + } } - if err = bucket.Put([]byte("UpRate"), i64ToB(uinfo.UpRate)); err != nil { - return err + if u.UpRate != nil { + if err = bucket.Put([]byte("UpRate"), i64ToB(*u.UpRate)); err != nil { + return err + } } - if err = bucket.Put([]byte("DownRate"), i64ToB(uinfo.DownRate)); err != nil { - return err + if u.DownRate != nil { + if err = bucket.Put([]byte("DownRate"), i64ToB(*u.DownRate)); err != nil { + return err + } } - if err = bucket.Put([]byte("UpCredit"), i64ToB(uinfo.UpCredit)); err != nil { - return err + if u.UpCredit != nil { + if err = bucket.Put([]byte("UpCredit"), i64ToB(*u.UpCredit)); err != nil { + return err + } } - if err = bucket.Put([]byte("DownCredit"), i64ToB(uinfo.DownCredit)); err != nil { - return err + if u.DownCredit != nil { + if err = bucket.Put([]byte("DownCredit"), i64ToB(*u.DownCredit)); err != nil { + return err + } } - if err = bucket.Put([]byte("ExpiryTime"), i64ToB(uinfo.ExpiryTime)); err != nil { - return err + if u.ExpiryTime != nil { + if err = bucket.Put([]byte("ExpiryTime"), i64ToB(*u.ExpiryTime)); err != nil { + return err + } } return nil }) diff --git a/internal/server/usermanager/localmanager_test.go b/internal/server/usermanager/localmanager_test.go index 9e9370b..40873cc 100644 --- a/internal/server/usermanager/localmanager_test.go +++ b/internal/server/usermanager/localmanager_test.go @@ -3,6 +3,7 @@ package usermanager import ( "encoding/binary" "github.com/cbeuw/Cloak/internal/common" + "github.com/stretchr/testify/assert" "io/ioutil" "math/rand" "os" @@ -17,12 +18,12 @@ var mockUID = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} var mockWorldState = common.WorldOfTime(time.Unix(1, 0)) var mockUserInfo = UserInfo{ UID: mockUID, - SessionsCap: 0, - UpRate: 0, - DownRate: 0, - UpCredit: 0, - DownCredit: 0, - ExpiryTime: 100, + SessionsCap: JustInt32(10), + UpRate: JustInt64(100), + DownRate: JustInt64(1000), + UpCredit: JustInt64(10000), + DownCredit: JustInt64(100000), + ExpiryTime: JustInt64(1000000), } func makeManager(t *testing.T) (mgr *localManager, cleaner func()) { @@ -43,6 +44,23 @@ func TestLocalManager_WriteUserInfo(t *testing.T) { if err != nil { t.Error(err) } + + got, err := mgr.GetUserInfo(mockUID) + assert.NoError(t, err) + assert.EqualValues(t, mockUserInfo, got) + + /* Partial update */ + err = mgr.WriteUserInfo(UserInfo{ + UID: mockUID, + SessionsCap: JustInt32(*mockUserInfo.SessionsCap + 1), + }) + assert.NoError(t, err) + + expected := mockUserInfo + expected.SessionsCap = JustInt32(*mockUserInfo.SessionsCap + 1) + got, err = mgr.GetUserInfo(mockUID) + assert.NoError(t, err) + assert.EqualValues(t, expected, got) } func TestLocalManager_GetUserInfo(t *testing.T) { @@ -63,7 +81,7 @@ func TestLocalManager_GetUserInfo(t *testing.T) { t.Run("update a field", func(t *testing.T) { _ = mgr.WriteUserInfo(mockUserInfo) updatedUserInfo := mockUserInfo - updatedUserInfo.SessionsCap = mockUserInfo.SessionsCap + 1 + updatedUserInfo.SessionsCap = JustInt32(*mockUserInfo.SessionsCap + 1) err := mgr.WriteUserInfo(updatedUserInfo) if err != nil { @@ -103,15 +121,7 @@ func TestLocalManager_DeleteUser(t *testing.T) { } } -var validUserInfo = UserInfo{ - UID: mockUID, - SessionsCap: 10, - UpRate: 100, - DownRate: 1000, - UpCredit: 10000, - DownCredit: 100000, - ExpiryTime: 1000000, -} +var validUserInfo = mockUserInfo func TestLocalManager_AuthenticateUser(t *testing.T) { var tmpDB, _ = ioutil.TempFile("", "ck_user_info") @@ -128,7 +138,7 @@ func TestLocalManager_AuthenticateUser(t *testing.T) { t.Error(err) } - if upRate != validUserInfo.UpRate || downRate != validUserInfo.DownRate { + if upRate != *validUserInfo.UpRate || downRate != *validUserInfo.DownRate { t.Error("wrong up or down rate") } }) @@ -142,7 +152,7 @@ func TestLocalManager_AuthenticateUser(t *testing.T) { t.Run("expired user", func(t *testing.T) { expiredUserInfo := validUserInfo - expiredUserInfo.ExpiryTime = mockWorldState.Now().Add(-10 * time.Second).Unix() + expiredUserInfo.ExpiryTime = JustInt64(mockWorldState.Now().Add(-10 * time.Second).Unix()) _ = mgr.WriteUserInfo(expiredUserInfo) @@ -154,7 +164,7 @@ func TestLocalManager_AuthenticateUser(t *testing.T) { t.Run("no credit", func(t *testing.T) { creditlessUserInfo := validUserInfo - creditlessUserInfo.UpCredit, creditlessUserInfo.DownCredit = -1, -1 + creditlessUserInfo.UpCredit, creditlessUserInfo.DownCredit = JustInt64(-1), JustInt64(-1) _ = mgr.WriteUserInfo(creditlessUserInfo) @@ -186,7 +196,7 @@ func TestLocalManager_AuthoriseNewSession(t *testing.T) { t.Run("expired user", func(t *testing.T) { expiredUserInfo := validUserInfo - expiredUserInfo.ExpiryTime = mockWorldState.Now().Add(-10 * time.Second).Unix() + expiredUserInfo.ExpiryTime = JustInt64(mockWorldState.Now().Add(-10 * time.Second).Unix()) _ = mgr.WriteUserInfo(expiredUserInfo) err := mgr.AuthoriseNewSession(expiredUserInfo.UID, AuthorisationInfo{NumExistingSessions: 0}) @@ -197,7 +207,7 @@ func TestLocalManager_AuthoriseNewSession(t *testing.T) { t.Run("too many sessions", func(t *testing.T) { _ = mgr.WriteUserInfo(validUserInfo) - err := mgr.AuthoriseNewSession(validUserInfo.UID, AuthorisationInfo{NumExistingSessions: int(validUserInfo.SessionsCap + 1)}) + err := mgr.AuthoriseNewSession(validUserInfo.UID, AuthorisationInfo{NumExistingSessions: int(*validUserInfo.SessionsCap + 1)}) if err != ErrSessionsCapReached { t.Error("session cap not reached") } @@ -230,10 +240,10 @@ func TestLocalManager_UploadStatus(t *testing.T) { t.Error(err) } - if updatedUserInfo.UpCredit != validUserInfo.UpCredit-update.UpUsage { + if *updatedUserInfo.UpCredit != *validUserInfo.UpCredit-update.UpUsage { t.Error("up usage incorrect") } - if updatedUserInfo.DownCredit != validUserInfo.DownCredit-update.DownUsage { + if *updatedUserInfo.DownCredit != *validUserInfo.DownCredit-update.DownUsage { t.Error("down usage incorrect") } }) @@ -249,7 +259,7 @@ func TestLocalManager_UploadStatus(t *testing.T) { UID: validUserInfo.UID, Active: true, NumSession: 1, - UpUsage: validUserInfo.UpCredit + 100, + UpUsage: *validUserInfo.UpCredit + 100, DownUsage: 0, Timestamp: mockWorldState.Now().Unix(), }, @@ -261,19 +271,19 @@ func TestLocalManager_UploadStatus(t *testing.T) { Active: true, NumSession: 1, UpUsage: 0, - DownUsage: validUserInfo.DownCredit + 100, + DownUsage: *validUserInfo.DownCredit + 100, Timestamp: mockWorldState.Now().Unix(), }, }, {"expired", UserInfo{ UID: mockUID, - SessionsCap: 10, - UpRate: 0, - DownRate: 0, - UpCredit: 0, - DownCredit: 0, - ExpiryTime: -1, + SessionsCap: JustInt32(10), + UpRate: JustInt64(0), + DownRate: JustInt64(0), + UpCredit: JustInt64(0), + DownCredit: JustInt64(0), + ExpiryTime: JustInt64(-1), }, StatusUpdate{ UID: mockUserInfo.UID, @@ -318,12 +328,12 @@ func TestLocalManager_ListAllUsers(t *testing.T) { rand.Read(randUID) newUser := UserInfo{ UID: randUID, - SessionsCap: rand.Int31(), - UpRate: rand.Int63(), - DownRate: rand.Int63(), - UpCredit: rand.Int63(), - DownCredit: rand.Int63(), - ExpiryTime: rand.Int63(), + SessionsCap: JustInt32(rand.Int31()), + UpRate: JustInt64(rand.Int63()), + DownRate: JustInt64(rand.Int63()), + UpCredit: JustInt64(rand.Int63()), + DownCredit: JustInt64(rand.Int63()), + ExpiryTime: JustInt64(rand.Int63()), } users = append(users, newUser) wg.Add(1) diff --git a/internal/server/usermanager/usermanager.go b/internal/server/usermanager/usermanager.go index 7bf84d5..bb5456e 100644 --- a/internal/server/usermanager/usermanager.go +++ b/internal/server/usermanager/usermanager.go @@ -14,16 +14,23 @@ type StatusUpdate struct { Timestamp int64 } +type MaybeInt32 *int32 +type MaybeInt64 *int64 + type UserInfo struct { UID []byte - SessionsCap int32 - UpRate int64 - DownRate int64 - UpCredit int64 - DownCredit int64 - ExpiryTime int64 + SessionsCap MaybeInt32 + UpRate MaybeInt64 + DownRate MaybeInt64 + UpCredit MaybeInt64 + DownCredit MaybeInt64 + ExpiryTime MaybeInt64 } +func JustInt32(v int32) MaybeInt32 { return &v } + +func JustInt64(v int64) MaybeInt64 { return &v } + type StatusResponse struct { UID []byte Action int diff --git a/internal/server/userpanel_test.go b/internal/server/userpanel_test.go index f74d3e9..b28744c 100644 --- a/internal/server/userpanel_test.go +++ b/internal/server/userpanel_test.go @@ -66,12 +66,12 @@ var mockUID = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} var mockWorldState = common.WorldOfTime(time.Unix(1, 0)) var validUserInfo = usermanager.UserInfo{ UID: mockUID, - SessionsCap: 10, - UpRate: 100, - DownRate: 1000, - UpCredit: 10000, - DownCredit: 100000, - ExpiryTime: 1000000, + SessionsCap: usermanager.JustInt32(10), + UpRate: usermanager.JustInt64(100), + DownRate: usermanager.JustInt64(1000), + UpCredit: usermanager.JustInt64(10000), + DownCredit: usermanager.JustInt64(100000), + ExpiryTime: usermanager.JustInt64(1000000), } func TestUserPanel_GetUser(t *testing.T) { @@ -138,10 +138,10 @@ func TestUserPanel_UpdateUsageQueue(t *testing.T) { } updatedUinfo, _ := mgr.GetUserInfo(validUserInfo.UID) - if updatedUinfo.DownCredit != validUserInfo.DownCredit-1 { + if *updatedUinfo.DownCredit != *validUserInfo.DownCredit-1 { t.Error("down credit incorrect update") } - if updatedUinfo.UpCredit != validUserInfo.UpCredit-2 { + if *updatedUinfo.UpCredit != *validUserInfo.UpCredit-2 { t.Error("up credit incorrect update") } @@ -155,10 +155,10 @@ func TestUserPanel_UpdateUsageQueue(t *testing.T) { } updatedUinfo, _ = mgr.GetUserInfo(validUserInfo.UID) - if updatedUinfo.DownCredit != validUserInfo.DownCredit-(1+3) { + if *updatedUinfo.DownCredit != *validUserInfo.DownCredit-(1+3) { t.Error("down credit incorrect update") } - if updatedUinfo.UpCredit != validUserInfo.UpCredit-(2+4) { + if *updatedUinfo.UpCredit != *validUserInfo.UpCredit-(2+4) { t.Error("up credit incorrect update") } }) @@ -170,7 +170,7 @@ func TestUserPanel_UpdateUsageQueue(t *testing.T) { t.Error(err) } - user.valve.AddTx(validUserInfo.DownCredit + 100) + user.valve.AddTx(*validUserInfo.DownCredit + 100) panel.updateUsageQueue() err = panel.commitUpdate() if err != nil { @@ -182,7 +182,7 @@ func TestUserPanel_UpdateUsageQueue(t *testing.T) { } updatedUinfo, _ := mgr.GetUserInfo(validUserInfo.UID) - if updatedUinfo.DownCredit != -100 { + if *updatedUinfo.DownCredit != -100 { t.Error("down credit not updated correctly after the user has been terminated") } }) diff --git a/internal/test/integration_test.go b/internal/test/integration_test.go index e4072e0..187507f 100644 --- a/internal/test/integration_test.go +++ b/internal/test/integration_test.go @@ -321,7 +321,7 @@ func TestTCPSingleplex(t *testing.T) { t.Fatal(err) } - const echoMsgLen = 16384 + const echoMsgLen = 1 << 16 go serveTCPEcho(proxyFromCkServerL) proxyConn1, err := proxyToCkClientD.Dial("", "") diff --git a/release.sh b/release.sh index bee82f8..61f317f 100755 --- a/release.sh +++ b/release.sh @@ -12,7 +12,7 @@ if [ -z "$v" ]; then fi output="{{.Dir}}-{{.OS}}-{{.Arch}}-$v" -osarch="!darwin/arm !darwin/arm64 !darwin/386" +osarch="!darwin/arm !darwin/386" echo "Compiling:"