Merge branch 'master' into notsure2

This commit is contained in:
notsure2 2021-04-06 20:28:34 +02:00
commit ba94c594f0
26 changed files with 377 additions and 540 deletions

View File

@ -135,6 +135,19 @@ encryption and authentication (via AEAD or similar techniques).**
`ServerName` is the domain you want to make your ISP or firewall _think_ you are visiting. Ideally it should
match `RedirAddr` in the server's configuration, a major site the censor allows, but it doesn't have to.
`AlternativeNames` is an array used alongside `ServerName` to shuffle between different ServerNames for every new
connection. **This may conflict with `CDN` Transport mode** if the CDN provider prohibits domain fronting and rejects
the alternative domains.
Example:
```json
{
"ServerName": "bing.com",
"AlternativeNames": ["cloudflare.com", "github.com"]
}
```
`CDNOriginHost` is the domain name of the _origin_ server (i.e. the server running Cloak) under `CDN` mode. This only
has effect when `Transport` is set to `CDN`. If unset, it will default to the remote hostname supplied via the
commandline argument (in standalone mode), or by Shadowsocks (in plugin mode). After a TLS session is established with

View File

@ -173,6 +173,11 @@ func main() {
log.Infof("Listening on %v %v for %v client", network, localConfig.LocalAddr, authInfo.ProxyMethod)
seshMaker = func() *mux.Session {
authInfo := authInfo // copy the struct because we are overwriting SessionId
randByte := make([]byte, 1)
common.RandRead(authInfo.WorldState.Rand, randByte)
authInfo.MockDomain = localConfig.MockDomainList[int(randByte[0])%len(localConfig.MockDomainList)]
// sessionID is usergenerated. There shouldn't be a security concern because the scope of
// sessionID is limited to its UID.
quad := make([]byte, 4)

View File

@ -1,4 +1,5 @@
// +build android
package main
// Stolen from https://github.com/shadowsocks/overture/blob/shadowsocks/core/utils/utils_android.go

View File

@ -9,49 +9,27 @@ import (
func TestParseBindAddr(t *testing.T) {
t.Run("port only", func(t *testing.T) {
addrs, err := resolveBindAddr([]string{":443"})
if err != nil {
t.Error(err)
return
}
if addrs[0].String() != ":443" {
t.Errorf("expected %v got %v", ":443", addrs[0].String())
}
assert.NoError(t, err)
assert.Equal(t, ":443", addrs[0].String())
})
t.Run("specific address", func(t *testing.T) {
addrs, err := resolveBindAddr([]string{"192.168.1.123:443"})
if err != nil {
t.Error(err)
return
}
if addrs[0].String() != "192.168.1.123:443" {
t.Errorf("expected %v got %v", "192.168.1.123:443", addrs[0].String())
}
assert.NoError(t, err)
assert.Equal(t, "192.168.1.123:443", addrs[0].String())
})
t.Run("ipv6", func(t *testing.T) {
addrs, err := resolveBindAddr([]string{"[::]:443"})
if err != nil {
t.Error(err)
return
}
if addrs[0].String() != "[::]:443" {
t.Errorf("expected %v got %v", "[::]:443", addrs[0].String())
}
assert.NoError(t, err)
assert.Equal(t, "[::]:443", addrs[0].String())
})
t.Run("mixed", func(t *testing.T) {
addrs, err := resolveBindAddr([]string{":80", "[::]:443"})
if err != nil {
t.Error(err)
return
}
if addrs[0].String() != ":80" {
t.Errorf("expected %v got %v", ":80", addrs[0].String())
}
if addrs[1].String() != "[::]:443" {
t.Errorf("expected %v got %v", "[::]:443", addrs[1].String())
}
assert.NoError(t, err)
assert.Equal(t, ":80", addrs[0].String())
assert.Equal(t, "[::]:443", addrs[1].String())
})
}

View File

@ -1,8 +1,8 @@
package client
import (
"bytes"
"encoding/hex"
"github.com/stretchr/testify/assert"
"testing"
)
@ -33,11 +33,6 @@ func TestMakeServerName(t *testing.T) {
}
for _, p := range pairs {
if !bytes.Equal(makeServerName(p.serverName), p.target) {
t.Error(
"for", p.serverName,
"expecting", p.target,
"got", makeServerName(p.serverName))
}
assert.Equal(t, p.target, makeServerName(p.serverName))
}
}

View File

@ -4,6 +4,7 @@ import (
"bytes"
"github.com/cbeuw/Cloak/internal/common"
"github.com/cbeuw/Cloak/internal/multiplex"
"github.com/stretchr/testify/assert"
"testing"
"time"
)
@ -64,12 +65,8 @@ func TestMakeAuthenticationPayload(t *testing.T) {
for _, tc := range tests {
func() {
payload, sharedSecret := makeAuthenticationPayload(tc.authInfo)
if payload != tc.expPayload {
t.Errorf("payload doesn't match:\nexp %v\ngot %v", tc.expPayload, payload)
}
if sharedSecret != tc.expSecret {
t.Errorf("secret doesn't match:\nexp %x\ngot %x", tc.expPayload, payload)
}
assert.Equal(t, tc.expPayload, payload, "payload doesn't match")
assert.Equal(t, tc.expSecret, sharedSecret, "shared secret doesn't match")
}()
}
}

View File

@ -30,7 +30,7 @@ type RawConfig struct {
LocalPort string // jsonOptional
RemoteHost string // jsonOptional
RemotePort string // jsonOptional
AlternativeNames []string // jsonOptional
// defaults set in ProcessRawConfig
UDP bool // nullable
BrowserSig string // nullable
@ -51,6 +51,7 @@ type RemoteConnConfig struct {
type LocalConnConfig struct {
LocalAddr string
Timeout time.Duration
MockDomainList []string
}
type AuthInfo struct {
@ -94,6 +95,20 @@ func ssvToJson(ssv string) (ret []byte) {
}
key := sp[0]
value := sp[1]
if strings.HasPrefix(key, "AlternativeNames") {
switch strings.Contains(value, ",") {
case true:
domains := strings.Split(value, ",")
for index, domain := range domains {
domains[index] = `"` + domain + `"`
}
value = strings.Join(domains, ",")
ret = append(ret, []byte(`"`+key+`":[`+value+`],`)...)
case false:
ret = append(ret, []byte(`"`+key+`":["`+value+`"],`)...)
}
continue
}
// JSON doesn't like quotation marks around int and bool
// This is extremely ugly but it's still better than writing a tokeniser
if elem(key, unquoted) {
@ -139,6 +154,8 @@ func (raw *RawConfig) ProcessRawConfig(worldState common.WorldState) (local Loca
return nullErr("ServerName")
}
auth.MockDomain = raw.ServerName
local.MockDomainList = raw.AlternativeNames
local.MockDomainList = append(local.MockDomainList, auth.MockDomain)
if raw.ProxyMethod == "" {
return nullErr("ServerName")
}

View File

@ -2,6 +2,7 @@ package common
import (
"encoding/binary"
"errors"
"io"
"net"
"sync"
@ -94,6 +95,9 @@ func (tls *TLSConn) Read(buffer []byte) (n int, err error) {
func (tls *TLSConn) Write(in []byte) (n int, err error) {
msgLen := len(in)
if msgLen > 1<<14+256 { // https://tools.ietf.org/html/rfc8446#section-5.2
return 0, errors.New("message is too long")
}
writeBuf := tls.writeBufPool.Get().(*[]byte)
*writeBuf = append(*writeBuf, byte(msgLen>>8), byte(msgLen&0xFF))
*writeBuf = append(*writeBuf, in...)

View File

@ -1,7 +1,7 @@
package multiplex
import (
"bytes"
"github.com/stretchr/testify/assert"
"testing"
"time"
)
@ -11,13 +11,7 @@ func TestDatagramBuffer_RW(t *testing.T) {
t.Run("simple write", func(t *testing.T) {
pipe := NewDatagramBufferedPipe()
_, err := pipe.Write(&Frame{Payload: b})
if err != nil {
t.Error(
"expecting", "nil error",
"got", err,
)
return
}
assert.NoError(t, err)
})
t.Run("simple read", func(t *testing.T) {
@ -25,50 +19,18 @@ func TestDatagramBuffer_RW(t *testing.T) {
_, _ = pipe.Write(&Frame{Payload: b})
b2 := make([]byte, len(b))
n, err := pipe.Read(b2)
if n != len(b) {
t.Error(
"For", "number of bytes read",
"expecting", len(b),
"got", n,
)
return
}
if err != nil {
t.Error(
"expecting", "nil error",
"got", err,
)
return
}
if !bytes.Equal(b, b2) {
t.Error(
"expecting", b,
"got", b2,
)
}
if pipe.buf.Len() != 0 {
t.Error("buf len is not 0 after finished reading")
return
}
assert.NoError(t, err)
assert.Equal(t, len(b), n)
assert.Equal(t, b, b2)
assert.Equal(t, 0, pipe.buf.Len(), "buf len is not 0 after finished reading")
})
t.Run("writing closing frame", func(t *testing.T) {
pipe := NewDatagramBufferedPipe()
toBeClosed, err := pipe.Write(&Frame{Closing: closingStream})
if !toBeClosed {
t.Error("should be to be closed")
}
if err != nil {
t.Error(
"expecting", "nil error",
"got", err,
)
return
}
if !pipe.closed {
t.Error("expecting closed pipe, not closed")
}
assert.NoError(t, err)
assert.True(t, toBeClosed, "should be to be closed")
assert.True(t, pipe.closed, "pipe should be closed")
})
}
@ -81,30 +43,9 @@ func TestDatagramBuffer_BlockingRead(t *testing.T) {
}()
b2 := make([]byte, len(b))
n, err := pipe.Read(b2)
if n != len(b) {
t.Error(
"For", "number of bytes read after block",
"expecting", len(b),
"got", n,
)
return
}
if err != nil {
t.Error(
"For", "blocked read",
"expecting", "nil error",
"got", err,
)
return
}
if !bytes.Equal(b, b2) {
t.Error(
"For", "blocked read",
"expecting", b,
"got", b2,
)
return
}
assert.NoError(t, err)
assert.Equal(t, len(b), n, "number of bytes read after block is wrong")
assert.Equal(t, b, b2)
}
func TestDatagramBuffer_CloseThenRead(t *testing.T) {
@ -114,27 +55,7 @@ func TestDatagramBuffer_CloseThenRead(t *testing.T) {
b2 := make([]byte, len(b))
pipe.Close()
n, err := pipe.Read(b2)
if n != len(b) {
t.Error(
"For", "number of bytes read",
"expecting", len(b),
"got", n,
)
}
if err != nil {
t.Error(
"For", "simple read",
"expecting", "nil error",
"got", err,
)
return
}
if !bytes.Equal(b, b2) {
t.Error(
"For", "simple read",
"expecting", b,
"got", b2,
)
return
}
assert.NoError(t, err)
assert.Equal(t, len(b), n, "number of bytes read after block is wrong")
assert.Equal(t, b, b2)
}

View File

@ -108,9 +108,7 @@ func TestMultiplex(t *testing.T) {
streams := make([]net.Conn, numStreams)
for i := 0; i < numStreams; i++ {
stream, err := clientSession.OpenStream()
if err != nil {
t.Fatalf("failed to open stream: %v", err)
}
assert.NoError(t, err)
streams[i] = stream
}
@ -123,18 +121,11 @@ func TestMultiplex(t *testing.T) {
// close one stream
closing, streams := streams[0], streams[1:]
err := closing.Close()
if err != nil {
t.Errorf("couldn't close a stream")
}
assert.NoError(t, err, "couldn't close a stream")
_, err = closing.Write([]byte{0})
if err != ErrBrokenStream {
t.Errorf("expecting error %v, got %v", ErrBrokenStream, err)
}
assert.Equal(t, ErrBrokenStream, err)
_, err = closing.Read(make([]byte, 1))
if err != ErrBrokenStream {
t.Errorf("expecting error %v, got %v", ErrBrokenStream, err)
}
assert.Equal(t, ErrBrokenStream, err)
}
func TestMux_StreamClosing(t *testing.T) {
@ -146,20 +137,13 @@ func TestMux_StreamClosing(t *testing.T) {
recvBuf := make([]byte, 128)
toBeClosed, _ := clientSession.OpenStream()
_, err := toBeClosed.Write(testData) // should be echoed back
if err != nil {
t.Errorf("can't write to stream: %v", err)
}
assert.NoError(t, err, "couldn't write to a stream")
_, err = io.ReadFull(toBeClosed, recvBuf[:1])
if err != nil {
t.Errorf("can't read anything before stream closed: %v", err)
}
assert.NoError(t, err, "can't read anything before stream closed")
_ = toBeClosed.Close()
_, err = io.ReadFull(toBeClosed, recvBuf[1:])
if err != nil {
t.Errorf("can't read residual data on stream: %v", err)
}
if !bytes.Equal(testData, recvBuf) {
t.Errorf("incorrect data read back")
}
assert.NoError(t, err, "can't read residual data on stream")
assert.Equal(t, testData, recvBuf, "incorrect data read back")
}

View File

@ -138,7 +138,7 @@ func BenchmarkObfs(b *testing.B) {
testPayload,
}
obfsBuf := make([]byte, defaultSendRecvBufSize)
obfsBuf := make([]byte, len(testPayload)*2)
var key [32]byte
rand.Read(key[:])
@ -211,7 +211,7 @@ func BenchmarkDeobfs(b *testing.B) {
testPayload,
}
obfsBuf := make([]byte, defaultSendRecvBufSize)
obfsBuf := make([]byte, len(testPayload)*2)
var key [32]byte
rand.Read(key[:])

View File

@ -25,4 +25,4 @@ type recvBuffer interface {
// size we want the amount of unread data in buffer to grow before recvBuffer.Write blocks.
// If the buffer grows larger than what the system's memory can offer at the time of recvBuffer.Write,
// a panic will happen.
const recvBufferSizeLimit = defaultSendRecvBufSize << 12
const recvBufferSizeLimit = 1<<31 - 1

View File

@ -14,9 +14,8 @@ import (
const (
acceptBacklog = 1024
// TODO: will this be a signature?
defaultSendRecvBufSize = 20480
defaultInactivityTimeout = 30 * time.Second
defaultMaxOnWireSize = 1<<14 + 256 // https://tools.ietf.org/html/rfc8446#section-5.2
)
var ErrBrokenSession = errors.New("broken session")
@ -24,8 +23,6 @@ var errRepeatSessionClosing = errors.New("trying to close a closed session")
var errRepeatStreamClosing = errors.New("trying to close a closed stream")
var errNoMultiplex = errors.New("a singleplexing session can have only one stream")
type switchboardStrategy int
type SessionConfig struct {
Obfuscator
@ -40,12 +37,6 @@ type SessionConfig struct {
// maximum size of an obfuscated frame, including headers and overhead
MsgOnWireSizeLimit int
// StreamSendBufferSize sets the buffer size used to send data from a Stream (Stream.obfsBuf)
StreamSendBufferSize int
// ConnReceiveBufferSize sets the buffer size used to receive data from an underlying Conn (allocated in
// switchboard.deplex)
ConnReceiveBufferSize int
// InactivityTimeout sets the duration a Session waits while it has no active streams before it closes itself
InactivityTimeout time.Duration
}
@ -82,11 +73,17 @@ type Session struct {
closed uint32
terminalMsg atomic.Value
terminalMsgSetter sync.Once
terminalMsg string
// the max size passed to Write calls before it splits it into multiple frames
// i.e. the max size a piece of data can fit into a Frame.Payload
maxStreamUnitWrite int
// streamSendBufferSize sets the buffer size used to send data from a Stream (Stream.obfsBuf)
streamSendBufferSize int
// connReceiveBufferSize sets the buffer size used to receive data from an underlying Conn (allocated in
// switchboard.deplex)
connReceiveBufferSize int
}
func MakeSession(id uint32, config SessionConfig) *Session {
@ -103,23 +100,19 @@ func MakeSession(id uint32, config SessionConfig) *Session {
if config.Valve == nil {
sesh.Valve = UNLIMITED_VALVE
}
if config.StreamSendBufferSize <= 0 {
sesh.StreamSendBufferSize = defaultSendRecvBufSize
}
if config.ConnReceiveBufferSize <= 0 {
sesh.ConnReceiveBufferSize = defaultSendRecvBufSize
}
if config.MsgOnWireSizeLimit <= 0 {
sesh.MsgOnWireSizeLimit = defaultSendRecvBufSize - 1024
sesh.MsgOnWireSizeLimit = defaultMaxOnWireSize
}
if config.InactivityTimeout == 0 {
sesh.InactivityTimeout = defaultInactivityTimeout
}
// todo: validation. this must be smaller than StreamSendBufferSize
sesh.maxStreamUnitWrite = sesh.MsgOnWireSizeLimit - frameHeaderLength - sesh.Obfuscator.maxOverhead
sesh.maxStreamUnitWrite = sesh.MsgOnWireSizeLimit - frameHeaderLength - sesh.maxOverhead
sesh.streamSendBufferSize = sesh.MsgOnWireSizeLimit
sesh.connReceiveBufferSize = 20480 // for backwards compatibility
sesh.streamObfsBufPool = sync.Pool{New: func() interface{} {
b := make([]byte, sesh.StreamSendBufferSize)
b := make([]byte, sesh.streamSendBufferSize)
return &b
}}
@ -187,7 +180,7 @@ func (sesh *Session) closeStream(s *Stream, active bool) error {
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
return fmt.Errorf("closing stream %v: %w", s.id, errRepeatStreamClosing)
}
_ = s.getRecvBuf().Close() // recvBuf.Close should not return error
_ = s.recvBuf.Close() // recvBuf.Close should not return error
if active {
tmpBuf := sesh.streamObfsBufPool.Get().(*[]byte)
@ -271,16 +264,13 @@ func (sesh *Session) recvDataFromRemote(data []byte) error {
}
func (sesh *Session) SetTerminalMsg(msg string) {
sesh.terminalMsg.Store(msg)
sesh.terminalMsgSetter.Do(func() {
sesh.terminalMsg = msg
})
}
func (sesh *Session) TerminalMsg() string {
msg := sesh.terminalMsg.Load()
if msg != nil {
return msg.(string)
} else {
return ""
}
return sesh.terminalMsg
}
func (sesh *Session) closeSession() error {
@ -292,14 +282,12 @@ func (sesh *Session) closeSession() error {
sesh.streamsM.Lock()
close(sesh.acceptCh)
for id, stream := range sesh.streams {
if stream == nil {
continue
}
atomic.StoreUint32(&stream.closed, 1)
_ = stream.getRecvBuf().Close() // will not block
if stream != nil && atomic.CompareAndSwapUint32(&stream.closed, 0, 1) {
_ = stream.recvBuf.Close() // will not block
delete(sesh.streams, id)
sesh.streamCountDecr()
}
}
sesh.streamsM.Unlock()
return nil
}
@ -339,7 +327,7 @@ func (sesh *Session) Close() error {
if err != nil {
return err
}
_, err = sesh.sb.send((*buf)[:i], new(uint32))
_, err = sesh.sb.send((*buf)[:i], new(net.Conn))
if err != nil {
return err
}

View File

@ -534,7 +534,7 @@ func TestSession_timeoutAfter(t *testing.T) {
func BenchmarkRecvDataFromRemote(b *testing.B) {
testPayload := make([]byte, testPayloadLen)
rand.Read(testPayload)
f := &Frame{
f := Frame{
1,
0,
0,
@ -544,12 +544,13 @@ func BenchmarkRecvDataFromRemote(b *testing.B) {
var sessionKey [32]byte
rand.Read(sessionKey[:])
const maxIter = 100_000 // run with -benchtime 100000x to avoid index out of bounds panic
const maxIter = 500_000 // run with -benchtime 500000x to avoid index out of bounds panic
for name, ep := range encryptionMethods {
ep := ep
b.Run(name, func(b *testing.B) {
for seshType, seshConfig := range seshConfigs {
b.Run(seshType, func(b *testing.B) {
f := f
seshConfig.Obfuscator, _ = MakeObfuscator(ep, sessionKey)
sesh := MakeSession(0, seshConfig)
@ -561,7 +562,7 @@ func BenchmarkRecvDataFromRemote(b *testing.B) {
binaryFrames := [maxIter][]byte{}
for i := 0; i < maxIter; i++ {
obfsBuf := make([]byte, obfsBufLen)
n, _ := sesh.obfuscate(f, obfsBuf, 0)
n, _ := sesh.obfuscate(&f, obfsBuf, 0)
binaryFrames[i] = obfsBuf[:n]
f.Seq++
}

View File

@ -23,9 +23,8 @@ type Stream struct {
session *Session
allocIdempot sync.Once
// a buffer (implemented as an asynchronous buffered pipe) to put data we've received from recvFrame but hasn't
// been read by the consumer through Read or WriteTo. Lazily allocated
// been read by the consumer through Read or WriteTo.
recvBuf recvBuffer
writingM sync.Mutex
@ -40,7 +39,7 @@ type Stream struct {
// recvBuffer (implemented by streamBuffer under ordered mode) will not receive out-of-order packets
// so it won't have to use its priority queue to sort it.
// This is not used in unordered connection mode
assignedConnId uint32
assignedConn net.Conn
readFromTimeout time.Duration
}
@ -56,25 +55,20 @@ func makeStream(sesh *Session, id uint32) *Stream {
},
}
if sesh.Unordered {
stream.recvBuf = NewDatagramBufferedPipe()
} else {
stream.recvBuf = NewStreamBuffer()
}
return stream
}
func (s *Stream) isClosed() bool { return atomic.LoadUint32(&s.closed) == 1 }
func (s *Stream) getRecvBuf() recvBuffer {
s.allocIdempot.Do(func() {
if s.session.Unordered {
s.recvBuf = NewDatagramBufferedPipe()
} else {
s.recvBuf = NewStreamBuffer()
}
})
return s.recvBuf
}
// receive a readily deobfuscated Frame so its payload can later be Read
func (s *Stream) recvFrame(frame *Frame) error {
toBeClosed, err := s.getRecvBuf().Write(frame)
toBeClosed, err := s.recvBuf.Write(frame)
if toBeClosed {
err = s.passiveClose()
if errors.Is(err, errRepeatStreamClosing) {
@ -93,7 +87,7 @@ func (s *Stream) Read(buf []byte) (n int, err error) {
return 0, nil
}
n, err = s.getRecvBuf().Read(buf)
n, err = s.recvBuf.Read(buf)
log.Tracef("%v read from stream %v with err %v", n, s.id, err)
if err == io.EOF {
return n, ErrBrokenStream
@ -104,7 +98,7 @@ func (s *Stream) Read(buf []byte) (n int, err error) {
// WriteTo continuously write data Stream has received into the writer w.
func (s *Stream) WriteTo(w io.Writer) (int64, error) {
// will keep writing until the underlying buffer is closed
n, err := s.getRecvBuf().WriteTo(w)
n, err := s.recvBuf.WriteTo(w)
log.Tracef("%v read from stream %v with err %v", n, s.id, err)
if err == io.EOF {
return n, ErrBrokenStream
@ -119,7 +113,7 @@ func (s *Stream) obfuscateAndSend(buf []byte, payloadOffsetInBuf int) error {
return err
}
_, err = s.session.sb.send(buf[:cipherTextLen], &s.assignedConnId)
_, err = s.session.sb.send(buf[:cipherTextLen], &s.assignedConn)
if err != nil {
if err == errBrokenSwitchboard {
s.session.SetTerminalMsg(err.Error())
@ -215,8 +209,8 @@ func (s *Stream) Close() error {
func (s *Stream) LocalAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[0] }
func (s *Stream) RemoteAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[1] }
func (s *Stream) SetWriteToTimeout(d time.Duration) { s.getRecvBuf().SetWriteToTimeout(d) }
func (s *Stream) SetReadDeadline(t time.Time) error { s.getRecvBuf().SetReadDeadline(t); return nil }
func (s *Stream) SetWriteToTimeout(d time.Duration) { s.recvBuf.SetWriteToTimeout(d) }
func (s *Stream) SetReadDeadline(t time.Time) error { s.recvBuf.SetReadDeadline(t); return nil }
func (s *Stream) SetReadFromTimeout(d time.Duration) { s.readFromTimeout = d }
var errNotImplemented = errors.New("Not implemented")

View File

@ -1,7 +1,7 @@
package multiplex
import (
"bytes"
"github.com/stretchr/testify/assert"
"math/rand"
"testing"
"time"
@ -13,49 +13,15 @@ func TestPipeRW(t *testing.T) {
pipe := NewStreamBufferedPipe()
b := []byte{0x01, 0x02, 0x03}
n, err := pipe.Write(b)
if n != len(b) {
t.Error(
"For", "number of bytes written",
"expecting", len(b),
"got", n,
)
return
}
if err != nil {
t.Error(
"For", "simple write",
"expecting", "nil error",
"got", err,
)
return
}
assert.NoError(t, err, "simple write")
assert.Equal(t, len(b), n, "number of bytes written")
b2 := make([]byte, len(b))
n, err = pipe.Read(b2)
if n != len(b) {
t.Error(
"For", "number of bytes read",
"expecting", len(b),
"got", n,
)
return
}
if err != nil {
t.Error(
"For", "simple read",
"expecting", "nil error",
"got", err,
)
return
}
if !bytes.Equal(b, b2) {
t.Error(
"For", "simple read",
"expecting", b,
"got", b2,
)
}
assert.NoError(t, err, "simple read")
assert.Equal(t, len(b), n, "number of bytes read")
assert.Equal(t, b, b2)
}
func TestReadBlock(t *testing.T) {
@ -67,30 +33,10 @@ func TestReadBlock(t *testing.T) {
}()
b2 := make([]byte, len(b))
n, err := pipe.Read(b2)
if n != len(b) {
t.Error(
"For", "number of bytes read after block",
"expecting", len(b),
"got", n,
)
return
}
if err != nil {
t.Error(
"For", "blocked read",
"expecting", "nil error",
"got", err,
)
return
}
if !bytes.Equal(b, b2) {
t.Error(
"For", "blocked read",
"expecting", b,
"got", b2,
)
return
}
assert.NoError(t, err, "blocked read")
assert.Equal(t, len(b), n, "number of bytes read after block")
assert.Equal(t, b, b2)
}
func TestPartialRead(t *testing.T) {
@ -99,54 +45,17 @@ func TestPartialRead(t *testing.T) {
pipe.Write(b)
b1 := make([]byte, 1)
n, err := pipe.Read(b1)
if n != len(b1) {
t.Error(
"For", "number of bytes in partial read of 1",
"expecting", len(b1),
"got", n,
)
return
}
if err != nil {
t.Error(
"For", "partial read of 1",
"expecting", "nil error",
"got", err,
)
return
}
if b1[0] != b[0] {
t.Error(
"For", "partial read of 1",
"expecting", b[0],
"got", b1[0],
)
}
assert.NoError(t, err, "partial read of 1")
assert.Equal(t, len(b1), n, "number of bytes in partial read of 1")
assert.Equal(t, b[0], b1[0])
b2 := make([]byte, 2)
n, err = pipe.Read(b2)
if n != len(b2) {
t.Error(
"For", "number of bytes in partial read of 2",
"expecting", len(b2),
"got", n,
)
}
if err != nil {
t.Error(
"For", "partial read of 2",
"expecting", "nil error",
"got", err,
)
return
}
if !bytes.Equal(b[1:], b2) {
t.Error(
"For", "partial read of 2",
"expecting", b[1:],
"got", b2,
)
return
}
assert.NoError(t, err, "partial read of 2")
assert.Equal(t, len(b2), n, "number of bytes in partial read of 2")
assert.Equal(t, b[1:], b2)
}
func TestReadAfterClose(t *testing.T) {
@ -156,29 +65,10 @@ func TestReadAfterClose(t *testing.T) {
b2 := make([]byte, len(b))
pipe.Close()
n, err := pipe.Read(b2)
if n != len(b) {
t.Error(
"For", "number of bytes read",
"expecting", len(b),
"got", n,
)
}
if err != nil {
t.Error(
"For", "simple read",
"expecting", "nil error",
"got", err,
)
return
}
if !bytes.Equal(b, b2) {
t.Error(
"For", "simple read",
"expecting", b,
"got", b2,
)
return
}
assert.NoError(t, err, "simple read")
assert.Equal(t, len(b), n, "number of bytes read")
assert.Equal(t, b, b2)
}
func BenchmarkBufferedPipe_RW(b *testing.B) {

View File

@ -10,9 +10,11 @@ import (
"time"
)
type switchboardStrategy int
const (
FIXED_CONN_MAPPING switchboardStrategy = iota
UNIFORM_SPREAD
fixedConnMapping switchboardStrategy = iota
uniformSpread
)
// switchboard represents the connection pool. It is responsible for managing
@ -28,10 +30,8 @@ type switchboard struct {
valve Valve
strategy switchboardStrategy
// map of connId to net.Conn
conns sync.Map
numConns uint32
nextConnId uint32
connsCount uint32
randPool sync.Pool
broken uint32
@ -41,15 +41,14 @@ func makeSwitchboard(sesh *Session) *switchboard {
var strategy switchboardStrategy
if sesh.Unordered {
log.Debug("Connection is unordered")
strategy = UNIFORM_SPREAD
strategy = uniformSpread
} else {
strategy = FIXED_CONN_MAPPING
strategy = fixedConnMapping
}
sb := &switchboard{
session: sesh,
strategy: strategy,
valve: sesh.Valve,
nextConnId: 1,
randPool: sync.Pool{New: func() interface{} {
return rand.New(rand.NewSource(int64(time.Now().Nanosecond())))
}},
@ -59,88 +58,85 @@ func makeSwitchboard(sesh *Session) *switchboard {
var errBrokenSwitchboard = errors.New("the switchboard is broken")
func (sb *switchboard) connsCount() int {
return int(atomic.LoadUint32(&sb.numConns))
}
func (sb *switchboard) addConn(conn net.Conn) {
connId := atomic.AddUint32(&sb.nextConnId, 1) - 1
atomic.AddUint32(&sb.numConns, 1)
sb.conns.Store(connId, conn)
go sb.deplex(connId, conn)
atomic.AddUint32(&sb.connsCount, 1)
sb.conns.Store(conn, conn)
go sb.deplex(conn)
}
// a pointer to connId is passed here so that the switchboard can reassign it if that connId isn't usable
func (sb *switchboard) send(data []byte, connId *uint32) (n int, err error) {
// a pointer to assignedConn is passed here so that the switchboard can reassign it if that conn isn't usable
func (sb *switchboard) send(data []byte, assignedConn *net.Conn) (n int, err error) {
sb.valve.txWait(len(data))
if atomic.LoadUint32(&sb.broken) == 1 || sb.connsCount() == 0 {
if atomic.LoadUint32(&sb.broken) == 1 {
return 0, errBrokenSwitchboard
}
var conn net.Conn
switch sb.strategy {
case UNIFORM_SPREAD:
_, conn, err = sb.pickRandConn()
case uniformSpread:
conn, err = sb.pickRandConn()
if err != nil {
return 0, errBrokenSwitchboard
}
case FIXED_CONN_MAPPING:
connI, ok := sb.conns.Load(*connId)
if ok {
conn = connI.(net.Conn)
} else {
var newConnId uint32
newConnId, conn, err = sb.pickRandConn()
n, err = conn.Write(data)
if err != nil {
return 0, errBrokenSwitchboard
sb.session.SetTerminalMsg("failed to send to remote " + err.Error())
sb.session.passiveClose()
return n, err
}
*connId = newConnId
case fixedConnMapping:
conn = *assignedConn
if conn == nil {
conn, err = sb.pickRandConn()
if err != nil {
sb.session.SetTerminalMsg("failed to pick a connection " + err.Error())
sb.session.passiveClose()
return 0, err
}
*assignedConn = conn
}
n, err = conn.Write(data)
if err != nil {
sb.session.SetTerminalMsg("failed to send to remote " + err.Error())
sb.session.passiveClose()
return n, err
}
default:
return 0, errors.New("unsupported traffic distribution strategy")
}
n, err = conn.Write(data)
if err != nil {
sb.conns.Delete(*connId)
sb.session.SetTerminalMsg("failed to write to remote " + err.Error())
sb.session.passiveClose()
return n, err
}
sb.valve.AddTx(int64(n))
return n, nil
}
// returns a random connId
func (sb *switchboard) pickRandConn() (uint32, net.Conn, error) {
connCount := sb.connsCount()
if atomic.LoadUint32(&sb.broken) == 1 || connCount == 0 {
return 0, nil, errBrokenSwitchboard
func (sb *switchboard) pickRandConn() (net.Conn, error) {
if atomic.LoadUint32(&sb.broken) == 1 {
return nil, errBrokenSwitchboard
}
connsCount := atomic.LoadUint32(&sb.connsCount)
if connsCount == 0 {
return nil, errBrokenSwitchboard
}
// there is no guarantee that sb.conns still has the same amount of entries
// between the count loop and the pick loop
// so if the r > len(sb.conns) at the point of range call, the last visited element is picked
var id uint32
var conn net.Conn
randReader := sb.randPool.Get().(*rand.Rand)
r := randReader.Intn(connCount)
r := randReader.Intn(int(connsCount))
sb.randPool.Put(randReader)
var c int
sb.conns.Range(func(connIdI, connI interface{}) bool {
var ret net.Conn
sb.conns.Range(func(_, conn interface{}) bool {
if r == c {
id = connIdI.(uint32)
conn = connI.(net.Conn)
ret = conn.(net.Conn)
return false
}
c++
return true
})
// if len(sb.conns) is 0
if conn == nil {
return 0, nil, errBrokenSwitchboard
}
return id, conn, nil
return ret, nil
}
// actively triggered by session.Close()
@ -148,26 +144,24 @@ func (sb *switchboard) closeAll() {
if !atomic.CompareAndSwapUint32(&sb.broken, 0, 1) {
return
}
sb.conns.Range(func(key, connI interface{}) bool {
conn := connI.(net.Conn)
conn.Close()
sb.conns.Delete(key)
sb.conns.Range(func(_, conn interface{}) bool {
conn.(net.Conn).Close()
sb.conns.Delete(conn)
atomic.AddUint32(&sb.connsCount, ^uint32(0))
return true
})
}
// deplex function costantly reads from a TCP connection
func (sb *switchboard) deplex(connId uint32, conn net.Conn) {
func (sb *switchboard) deplex(conn net.Conn) {
defer conn.Close()
buf := make([]byte, sb.session.ConnReceiveBufferSize)
buf := make([]byte, sb.session.connReceiveBufferSize)
for {
n, err := conn.Read(buf)
sb.valve.rxWait(n)
sb.valve.AddRx(int64(n))
if err != nil {
log.Debugf("a connection for session %v has closed: %v", sb.session.id, err)
sb.conns.Delete(connId)
atomic.AddUint32(&sb.numConns, ^uint32(0))
sb.session.SetTerminalMsg("a connection has dropped unexpectedly")
sb.session.passiveClose()
return

View File

@ -5,6 +5,7 @@ import (
"github.com/stretchr/testify/assert"
"math/rand"
"sync"
"sync/atomic"
"testing"
"time"
)
@ -14,14 +15,14 @@ func TestSwitchboard_Send(t *testing.T) {
sesh := MakeSession(0, seshConfig)
hole0 := connutil.Discard()
sesh.sb.addConn(hole0)
connId, _, err := sesh.sb.pickRandConn()
conn, err := sesh.sb.pickRandConn()
if err != nil {
t.Error("failed to get a random conn", err)
return
}
data := make([]byte, 1000)
rand.Read(data)
_, err = sesh.sb.send(data, &connId)
_, err = sesh.sb.send(data, &conn)
if err != nil {
t.Error(err)
return
@ -29,23 +30,23 @@ func TestSwitchboard_Send(t *testing.T) {
hole1 := connutil.Discard()
sesh.sb.addConn(hole1)
connId, _, err = sesh.sb.pickRandConn()
conn, err = sesh.sb.pickRandConn()
if err != nil {
t.Error("failed to get a random conn", err)
return
}
_, err = sesh.sb.send(data, &connId)
_, err = sesh.sb.send(data, &conn)
if err != nil {
t.Error(err)
return
}
connId, _, err = sesh.sb.pickRandConn()
conn, err = sesh.sb.pickRandConn()
if err != nil {
t.Error("failed to get a random conn", err)
return
}
_, err = sesh.sb.send(data, &connId)
_, err = sesh.sb.send(data, &conn)
if err != nil {
t.Error(err)
return
@ -71,7 +72,7 @@ func BenchmarkSwitchboard_Send(b *testing.B) {
seshConfig := SessionConfig{}
sesh := MakeSession(0, seshConfig)
sesh.sb.addConn(hole)
connId, _, err := sesh.sb.pickRandConn()
conn, err := sesh.sb.pickRandConn()
if err != nil {
b.Error("failed to get a random conn", err)
return
@ -81,7 +82,7 @@ func BenchmarkSwitchboard_Send(b *testing.B) {
b.SetBytes(int64(len(data)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
sesh.sb.send(data, &connId)
sesh.sb.send(data, &conn)
}
}
@ -92,7 +93,7 @@ func TestSwitchboard_TxCredit(t *testing.T) {
sesh := MakeSession(0, seshConfig)
hole := connutil.Discard()
sesh.sb.addConn(hole)
connId, _, err := sesh.sb.pickRandConn()
conn, err := sesh.sb.pickRandConn()
if err != nil {
t.Error("failed to get a random conn", err)
return
@ -100,10 +101,10 @@ func TestSwitchboard_TxCredit(t *testing.T) {
data := make([]byte, 1000)
rand.Read(data)
t.Run("FIXED CONN MAPPING", func(t *testing.T) {
t.Run("fixed conn mapping", func(t *testing.T) {
*sesh.sb.valve.(*LimitedValve).tx = 0
sesh.sb.strategy = FIXED_CONN_MAPPING
n, err := sesh.sb.send(data[:10], &connId)
sesh.sb.strategy = fixedConnMapping
n, err := sesh.sb.send(data[:10], &conn)
if err != nil {
t.Error(err)
return
@ -116,10 +117,10 @@ func TestSwitchboard_TxCredit(t *testing.T) {
t.Error("tx credit didn't increase by 10")
}
})
t.Run("UNIFORM", func(t *testing.T) {
t.Run("uniform spread", func(t *testing.T) {
*sesh.sb.valve.(*LimitedValve).tx = 0
sesh.sb.strategy = UNIFORM_SPREAD
n, err := sesh.sb.send(data[:10], &connId)
sesh.sb.strategy = uniformSpread
n, err := sesh.sb.send(data[:10], &conn)
if err != nil {
t.Error(err)
return
@ -173,13 +174,13 @@ func TestSwitchboard_ConnsCount(t *testing.T) {
}
wg.Wait()
if sesh.sb.connsCount() != 1000 {
if atomic.LoadUint32(&sesh.sb.connsCount) != 1000 {
t.Error("connsCount incorrect")
}
sesh.sb.closeAll()
assert.Eventuallyf(t, func() bool {
return sesh.sb.connsCount() == 0
}, time.Second, 10*time.Millisecond, "connsCount incorrect: %v", sesh.sb.connsCount())
return atomic.LoadUint32(&sesh.sb.connsCount) == 0
}, time.Second, 10*time.Millisecond, "connsCount incorrect: %v", atomic.LoadUint32(&sesh.sb.connsCount))
}

View File

@ -2,7 +2,7 @@ swagger: '2.0'
info:
description: |
This is the API of Cloak server
version: 1.0.0
version: 0.0.2
title: Cloak Server
contact:
email: cbeuw.andy@gmail.com
@ -12,8 +12,6 @@ info:
# host: petstore.swagger.io
# basePath: /v2
tags:
- name: admin
description: Endpoints used by the host administrators
- name: users
description: Operations related to user controls by admin
# schemes:
@ -22,7 +20,6 @@ paths:
/admin/users:
get:
tags:
- admin
- users
summary: Show all users
description: Returns an array of all UserInfo
@ -41,7 +38,6 @@ paths:
/admin/users/{UID}:
get:
tags:
- admin
- users
summary: Show userinfo by UID
description: Returns a UserInfo object
@ -68,7 +64,6 @@ paths:
description: internal error
post:
tags:
- admin
- users
summary: Updates the userinfo of the specified user, if the user does not exist, then a new user is created
operationId: writeUserInfo
@ -100,7 +95,6 @@ paths:
description: internal error
delete:
tags:
- admin
- users
summary: Deletes a user
operationId: deleteUser

View File

@ -46,6 +46,36 @@ func TestWriteUserInfoHlr(t *testing.T) {
assert.Equalf(t, http.StatusCreated, rr.Code, "response body: %v", rr.Body)
})
t.Run("partial update", func(t *testing.T) {
req, err := http.NewRequest("POST", "/admin/users/"+mockUIDb64, bytes.NewBuffer(marshalled))
assert.NoError(t, err)
rr := httptest.NewRecorder()
router.ServeHTTP(rr, req)
assert.Equal(t, http.StatusCreated, rr.Code)
partialUserInfo := UserInfo{
UID: mockUID,
SessionsCap: JustInt32(10),
}
partialMarshalled, _ := json.Marshal(partialUserInfo)
req, err = http.NewRequest("POST", "/admin/users/"+mockUIDb64, bytes.NewBuffer(partialMarshalled))
assert.NoError(t, err)
router.ServeHTTP(rr, req)
assert.Equal(t, http.StatusCreated, rr.Code)
req, err = http.NewRequest("GET", "/admin/users/"+mockUIDb64, nil)
assert.NoError(t, err)
router.ServeHTTP(rr, req)
assert.Equal(t, http.StatusCreated, rr.Code)
var got UserInfo
err = json.Unmarshal(rr.Body.Bytes(), &got)
assert.NoError(t, err)
expected := mockUserInfo
expected.SessionsCap = partialUserInfo.SessionsCap
assert.EqualValues(t, expected, got)
})
t.Run("empty parameter", func(t *testing.T) {
req, err := http.NewRequest("POST", "/admin/users/", bytes.NewBuffer(marshalled))
if err != nil {

View File

@ -127,6 +127,7 @@ func (manager *localManager) UploadStatus(uploads []StatusUpdate) ([]StatusRespo
"User no longer exists",
}
responses = append(responses, resp)
continue
}
oldUp := int64(u64(bucket.Get([]byte("UpCredit"))))
@ -179,12 +180,12 @@ func (manager *localManager) ListAllUsers() (infos []UserInfo, err error) {
err = tx.ForEach(func(UID []byte, bucket *bolt.Bucket) error {
var uinfo UserInfo
uinfo.UID = UID
uinfo.SessionsCap = int32(u32(bucket.Get([]byte("SessionsCap"))))
uinfo.UpRate = int64(u64(bucket.Get([]byte("UpRate"))))
uinfo.DownRate = int64(u64(bucket.Get([]byte("DownRate"))))
uinfo.UpCredit = int64(u64(bucket.Get([]byte("UpCredit"))))
uinfo.DownCredit = int64(u64(bucket.Get([]byte("DownCredit"))))
uinfo.ExpiryTime = int64(u64(bucket.Get([]byte("ExpiryTime"))))
uinfo.SessionsCap = JustInt32(int32(u32(bucket.Get([]byte("SessionsCap")))))
uinfo.UpRate = JustInt64(int64(u64(bucket.Get([]byte("UpRate")))))
uinfo.DownRate = JustInt64(int64(u64(bucket.Get([]byte("DownRate")))))
uinfo.UpCredit = JustInt64(int64(u64(bucket.Get([]byte("UpCredit")))))
uinfo.DownCredit = JustInt64(int64(u64(bucket.Get([]byte("DownCredit")))))
uinfo.ExpiryTime = JustInt64(int64(u64(bucket.Get([]byte("ExpiryTime")))))
infos = append(infos, uinfo)
return nil
})
@ -200,41 +201,53 @@ func (manager *localManager) GetUserInfo(UID []byte) (uinfo UserInfo, err error)
return ErrUserNotFound
}
uinfo.UID = UID
uinfo.SessionsCap = int32(u32(bucket.Get([]byte("SessionsCap"))))
uinfo.UpRate = int64(u64(bucket.Get([]byte("UpRate"))))
uinfo.DownRate = int64(u64(bucket.Get([]byte("DownRate"))))
uinfo.UpCredit = int64(u64(bucket.Get([]byte("UpCredit"))))
uinfo.DownCredit = int64(u64(bucket.Get([]byte("DownCredit"))))
uinfo.ExpiryTime = int64(u64(bucket.Get([]byte("ExpiryTime"))))
uinfo.SessionsCap = JustInt32(int32(u32(bucket.Get([]byte("SessionsCap")))))
uinfo.UpRate = JustInt64(int64(u64(bucket.Get([]byte("UpRate")))))
uinfo.DownRate = JustInt64(int64(u64(bucket.Get([]byte("DownRate")))))
uinfo.UpCredit = JustInt64(int64(u64(bucket.Get([]byte("UpCredit")))))
uinfo.DownCredit = JustInt64(int64(u64(bucket.Get([]byte("DownCredit")))))
uinfo.ExpiryTime = JustInt64(int64(u64(bucket.Get([]byte("ExpiryTime")))))
return nil
})
return
}
func (manager *localManager) WriteUserInfo(uinfo UserInfo) (err error) {
func (manager *localManager) WriteUserInfo(u UserInfo) (err error) {
err = manager.db.Update(func(tx *bolt.Tx) error {
bucket, err := tx.CreateBucketIfNotExists(uinfo.UID)
bucket, err := tx.CreateBucketIfNotExists(u.UID)
if err != nil {
return err
}
if err = bucket.Put([]byte("SessionsCap"), i32ToB(int32(uinfo.SessionsCap))); err != nil {
if u.SessionsCap != nil {
if err = bucket.Put([]byte("SessionsCap"), i32ToB(*u.SessionsCap)); err != nil {
return err
}
if err = bucket.Put([]byte("UpRate"), i64ToB(uinfo.UpRate)); err != nil {
}
if u.UpRate != nil {
if err = bucket.Put([]byte("UpRate"), i64ToB(*u.UpRate)); err != nil {
return err
}
if err = bucket.Put([]byte("DownRate"), i64ToB(uinfo.DownRate)); err != nil {
}
if u.DownRate != nil {
if err = bucket.Put([]byte("DownRate"), i64ToB(*u.DownRate)); err != nil {
return err
}
if err = bucket.Put([]byte("UpCredit"), i64ToB(uinfo.UpCredit)); err != nil {
}
if u.UpCredit != nil {
if err = bucket.Put([]byte("UpCredit"), i64ToB(*u.UpCredit)); err != nil {
return err
}
if err = bucket.Put([]byte("DownCredit"), i64ToB(uinfo.DownCredit)); err != nil {
}
if u.DownCredit != nil {
if err = bucket.Put([]byte("DownCredit"), i64ToB(*u.DownCredit)); err != nil {
return err
}
if err = bucket.Put([]byte("ExpiryTime"), i64ToB(uinfo.ExpiryTime)); err != nil {
}
if u.ExpiryTime != nil {
if err = bucket.Put([]byte("ExpiryTime"), i64ToB(*u.ExpiryTime)); err != nil {
return err
}
}
return nil
})
return

View File

@ -3,6 +3,7 @@ package usermanager
import (
"encoding/binary"
"github.com/cbeuw/Cloak/internal/common"
"github.com/stretchr/testify/assert"
"io/ioutil"
"math/rand"
"os"
@ -17,12 +18,12 @@ var mockUID = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
var mockWorldState = common.WorldOfTime(time.Unix(1, 0))
var mockUserInfo = UserInfo{
UID: mockUID,
SessionsCap: 0,
UpRate: 0,
DownRate: 0,
UpCredit: 0,
DownCredit: 0,
ExpiryTime: 100,
SessionsCap: JustInt32(10),
UpRate: JustInt64(100),
DownRate: JustInt64(1000),
UpCredit: JustInt64(10000),
DownCredit: JustInt64(100000),
ExpiryTime: JustInt64(1000000),
}
func makeManager(t *testing.T) (mgr *localManager, cleaner func()) {
@ -43,6 +44,23 @@ func TestLocalManager_WriteUserInfo(t *testing.T) {
if err != nil {
t.Error(err)
}
got, err := mgr.GetUserInfo(mockUID)
assert.NoError(t, err)
assert.EqualValues(t, mockUserInfo, got)
/* Partial update */
err = mgr.WriteUserInfo(UserInfo{
UID: mockUID,
SessionsCap: JustInt32(*mockUserInfo.SessionsCap + 1),
})
assert.NoError(t, err)
expected := mockUserInfo
expected.SessionsCap = JustInt32(*mockUserInfo.SessionsCap + 1)
got, err = mgr.GetUserInfo(mockUID)
assert.NoError(t, err)
assert.EqualValues(t, expected, got)
}
func TestLocalManager_GetUserInfo(t *testing.T) {
@ -63,7 +81,7 @@ func TestLocalManager_GetUserInfo(t *testing.T) {
t.Run("update a field", func(t *testing.T) {
_ = mgr.WriteUserInfo(mockUserInfo)
updatedUserInfo := mockUserInfo
updatedUserInfo.SessionsCap = mockUserInfo.SessionsCap + 1
updatedUserInfo.SessionsCap = JustInt32(*mockUserInfo.SessionsCap + 1)
err := mgr.WriteUserInfo(updatedUserInfo)
if err != nil {
@ -103,15 +121,7 @@ func TestLocalManager_DeleteUser(t *testing.T) {
}
}
var validUserInfo = UserInfo{
UID: mockUID,
SessionsCap: 10,
UpRate: 100,
DownRate: 1000,
UpCredit: 10000,
DownCredit: 100000,
ExpiryTime: 1000000,
}
var validUserInfo = mockUserInfo
func TestLocalManager_AuthenticateUser(t *testing.T) {
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
@ -128,7 +138,7 @@ func TestLocalManager_AuthenticateUser(t *testing.T) {
t.Error(err)
}
if upRate != validUserInfo.UpRate || downRate != validUserInfo.DownRate {
if upRate != *validUserInfo.UpRate || downRate != *validUserInfo.DownRate {
t.Error("wrong up or down rate")
}
})
@ -142,7 +152,7 @@ func TestLocalManager_AuthenticateUser(t *testing.T) {
t.Run("expired user", func(t *testing.T) {
expiredUserInfo := validUserInfo
expiredUserInfo.ExpiryTime = mockWorldState.Now().Add(-10 * time.Second).Unix()
expiredUserInfo.ExpiryTime = JustInt64(mockWorldState.Now().Add(-10 * time.Second).Unix())
_ = mgr.WriteUserInfo(expiredUserInfo)
@ -154,7 +164,7 @@ func TestLocalManager_AuthenticateUser(t *testing.T) {
t.Run("no credit", func(t *testing.T) {
creditlessUserInfo := validUserInfo
creditlessUserInfo.UpCredit, creditlessUserInfo.DownCredit = -1, -1
creditlessUserInfo.UpCredit, creditlessUserInfo.DownCredit = JustInt64(-1), JustInt64(-1)
_ = mgr.WriteUserInfo(creditlessUserInfo)
@ -186,7 +196,7 @@ func TestLocalManager_AuthoriseNewSession(t *testing.T) {
t.Run("expired user", func(t *testing.T) {
expiredUserInfo := validUserInfo
expiredUserInfo.ExpiryTime = mockWorldState.Now().Add(-10 * time.Second).Unix()
expiredUserInfo.ExpiryTime = JustInt64(mockWorldState.Now().Add(-10 * time.Second).Unix())
_ = mgr.WriteUserInfo(expiredUserInfo)
err := mgr.AuthoriseNewSession(expiredUserInfo.UID, AuthorisationInfo{NumExistingSessions: 0})
@ -197,7 +207,7 @@ func TestLocalManager_AuthoriseNewSession(t *testing.T) {
t.Run("too many sessions", func(t *testing.T) {
_ = mgr.WriteUserInfo(validUserInfo)
err := mgr.AuthoriseNewSession(validUserInfo.UID, AuthorisationInfo{NumExistingSessions: int(validUserInfo.SessionsCap + 1)})
err := mgr.AuthoriseNewSession(validUserInfo.UID, AuthorisationInfo{NumExistingSessions: int(*validUserInfo.SessionsCap + 1)})
if err != ErrSessionsCapReached {
t.Error("session cap not reached")
}
@ -230,10 +240,10 @@ func TestLocalManager_UploadStatus(t *testing.T) {
t.Error(err)
}
if updatedUserInfo.UpCredit != validUserInfo.UpCredit-update.UpUsage {
if *updatedUserInfo.UpCredit != *validUserInfo.UpCredit-update.UpUsage {
t.Error("up usage incorrect")
}
if updatedUserInfo.DownCredit != validUserInfo.DownCredit-update.DownUsage {
if *updatedUserInfo.DownCredit != *validUserInfo.DownCredit-update.DownUsage {
t.Error("down usage incorrect")
}
})
@ -249,7 +259,7 @@ func TestLocalManager_UploadStatus(t *testing.T) {
UID: validUserInfo.UID,
Active: true,
NumSession: 1,
UpUsage: validUserInfo.UpCredit + 100,
UpUsage: *validUserInfo.UpCredit + 100,
DownUsage: 0,
Timestamp: mockWorldState.Now().Unix(),
},
@ -261,19 +271,19 @@ func TestLocalManager_UploadStatus(t *testing.T) {
Active: true,
NumSession: 1,
UpUsage: 0,
DownUsage: validUserInfo.DownCredit + 100,
DownUsage: *validUserInfo.DownCredit + 100,
Timestamp: mockWorldState.Now().Unix(),
},
},
{"expired",
UserInfo{
UID: mockUID,
SessionsCap: 10,
UpRate: 0,
DownRate: 0,
UpCredit: 0,
DownCredit: 0,
ExpiryTime: -1,
SessionsCap: JustInt32(10),
UpRate: JustInt64(0),
DownRate: JustInt64(0),
UpCredit: JustInt64(0),
DownCredit: JustInt64(0),
ExpiryTime: JustInt64(-1),
},
StatusUpdate{
UID: mockUserInfo.UID,
@ -318,12 +328,12 @@ func TestLocalManager_ListAllUsers(t *testing.T) {
rand.Read(randUID)
newUser := UserInfo{
UID: randUID,
SessionsCap: rand.Int31(),
UpRate: rand.Int63(),
DownRate: rand.Int63(),
UpCredit: rand.Int63(),
DownCredit: rand.Int63(),
ExpiryTime: rand.Int63(),
SessionsCap: JustInt32(rand.Int31()),
UpRate: JustInt64(rand.Int63()),
DownRate: JustInt64(rand.Int63()),
UpCredit: JustInt64(rand.Int63()),
DownCredit: JustInt64(rand.Int63()),
ExpiryTime: JustInt64(rand.Int63()),
}
users = append(users, newUser)
wg.Add(1)

View File

@ -14,16 +14,23 @@ type StatusUpdate struct {
Timestamp int64
}
type MaybeInt32 *int32
type MaybeInt64 *int64
type UserInfo struct {
UID []byte
SessionsCap int32
UpRate int64
DownRate int64
UpCredit int64
DownCredit int64
ExpiryTime int64
SessionsCap MaybeInt32
UpRate MaybeInt64
DownRate MaybeInt64
UpCredit MaybeInt64
DownCredit MaybeInt64
ExpiryTime MaybeInt64
}
func JustInt32(v int32) MaybeInt32 { return &v }
func JustInt64(v int64) MaybeInt64 { return &v }
type StatusResponse struct {
UID []byte
Action int

View File

@ -66,12 +66,12 @@ var mockUID = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
var mockWorldState = common.WorldOfTime(time.Unix(1, 0))
var validUserInfo = usermanager.UserInfo{
UID: mockUID,
SessionsCap: 10,
UpRate: 100,
DownRate: 1000,
UpCredit: 10000,
DownCredit: 100000,
ExpiryTime: 1000000,
SessionsCap: usermanager.JustInt32(10),
UpRate: usermanager.JustInt64(100),
DownRate: usermanager.JustInt64(1000),
UpCredit: usermanager.JustInt64(10000),
DownCredit: usermanager.JustInt64(100000),
ExpiryTime: usermanager.JustInt64(1000000),
}
func TestUserPanel_GetUser(t *testing.T) {
@ -138,10 +138,10 @@ func TestUserPanel_UpdateUsageQueue(t *testing.T) {
}
updatedUinfo, _ := mgr.GetUserInfo(validUserInfo.UID)
if updatedUinfo.DownCredit != validUserInfo.DownCredit-1 {
if *updatedUinfo.DownCredit != *validUserInfo.DownCredit-1 {
t.Error("down credit incorrect update")
}
if updatedUinfo.UpCredit != validUserInfo.UpCredit-2 {
if *updatedUinfo.UpCredit != *validUserInfo.UpCredit-2 {
t.Error("up credit incorrect update")
}
@ -155,10 +155,10 @@ func TestUserPanel_UpdateUsageQueue(t *testing.T) {
}
updatedUinfo, _ = mgr.GetUserInfo(validUserInfo.UID)
if updatedUinfo.DownCredit != validUserInfo.DownCredit-(1+3) {
if *updatedUinfo.DownCredit != *validUserInfo.DownCredit-(1+3) {
t.Error("down credit incorrect update")
}
if updatedUinfo.UpCredit != validUserInfo.UpCredit-(2+4) {
if *updatedUinfo.UpCredit != *validUserInfo.UpCredit-(2+4) {
t.Error("up credit incorrect update")
}
})
@ -170,7 +170,7 @@ func TestUserPanel_UpdateUsageQueue(t *testing.T) {
t.Error(err)
}
user.valve.AddTx(validUserInfo.DownCredit + 100)
user.valve.AddTx(*validUserInfo.DownCredit + 100)
panel.updateUsageQueue()
err = panel.commitUpdate()
if err != nil {
@ -182,7 +182,7 @@ func TestUserPanel_UpdateUsageQueue(t *testing.T) {
}
updatedUinfo, _ := mgr.GetUserInfo(validUserInfo.UID)
if updatedUinfo.DownCredit != -100 {
if *updatedUinfo.DownCredit != -100 {
t.Error("down credit not updated correctly after the user has been terminated")
}
})

View File

@ -321,7 +321,7 @@ func TestTCPSingleplex(t *testing.T) {
t.Fatal(err)
}
const echoMsgLen = 16384
const echoMsgLen = 1 << 16
go serveTCPEcho(proxyFromCkServerL)
proxyConn1, err := proxyToCkClientD.Dial("", "")

View File

@ -12,7 +12,7 @@ if [ -z "$v" ]; then
fi
output="{{.Dir}}-{{.OS}}-{{.Arch}}-$v"
osarch="!darwin/arm !darwin/arm64 !darwin/386"
osarch="!darwin/arm !darwin/386"
echo "Compiling:"