Merge branch 'master' into notsure2

This commit is contained in:
notsure2 2021-04-06 20:28:34 +02:00
commit ba94c594f0
26 changed files with 377 additions and 540 deletions

View File

@ -135,6 +135,19 @@ encryption and authentication (via AEAD or similar techniques).**
`ServerName` is the domain you want to make your ISP or firewall _think_ you are visiting. Ideally it should `ServerName` is the domain you want to make your ISP or firewall _think_ you are visiting. Ideally it should
match `RedirAddr` in the server's configuration, a major site the censor allows, but it doesn't have to. match `RedirAddr` in the server's configuration, a major site the censor allows, but it doesn't have to.
`AlternativeNames` is an array used alongside `ServerName` to shuffle between different ServerNames for every new
connection. **This may conflict with `CDN` Transport mode** if the CDN provider prohibits domain fronting and rejects
the alternative domains.
Example:
```json
{
"ServerName": "bing.com",
"AlternativeNames": ["cloudflare.com", "github.com"]
}
```
`CDNOriginHost` is the domain name of the _origin_ server (i.e. the server running Cloak) under `CDN` mode. This only `CDNOriginHost` is the domain name of the _origin_ server (i.e. the server running Cloak) under `CDN` mode. This only
has effect when `Transport` is set to `CDN`. If unset, it will default to the remote hostname supplied via the has effect when `Transport` is set to `CDN`. If unset, it will default to the remote hostname supplied via the
commandline argument (in standalone mode), or by Shadowsocks (in plugin mode). After a TLS session is established with commandline argument (in standalone mode), or by Shadowsocks (in plugin mode). After a TLS session is established with

View File

@ -173,6 +173,11 @@ func main() {
log.Infof("Listening on %v %v for %v client", network, localConfig.LocalAddr, authInfo.ProxyMethod) log.Infof("Listening on %v %v for %v client", network, localConfig.LocalAddr, authInfo.ProxyMethod)
seshMaker = func() *mux.Session { seshMaker = func() *mux.Session {
authInfo := authInfo // copy the struct because we are overwriting SessionId authInfo := authInfo // copy the struct because we are overwriting SessionId
randByte := make([]byte, 1)
common.RandRead(authInfo.WorldState.Rand, randByte)
authInfo.MockDomain = localConfig.MockDomainList[int(randByte[0])%len(localConfig.MockDomainList)]
// sessionID is usergenerated. There shouldn't be a security concern because the scope of // sessionID is usergenerated. There shouldn't be a security concern because the scope of
// sessionID is limited to its UID. // sessionID is limited to its UID.
quad := make([]byte, 4) quad := make([]byte, 4)

View File

@ -1,4 +1,5 @@
// +build android // +build android
package main package main
// Stolen from https://github.com/shadowsocks/overture/blob/shadowsocks/core/utils/utils_android.go // Stolen from https://github.com/shadowsocks/overture/blob/shadowsocks/core/utils/utils_android.go

View File

@ -9,49 +9,27 @@ import (
func TestParseBindAddr(t *testing.T) { func TestParseBindAddr(t *testing.T) {
t.Run("port only", func(t *testing.T) { t.Run("port only", func(t *testing.T) {
addrs, err := resolveBindAddr([]string{":443"}) addrs, err := resolveBindAddr([]string{":443"})
if err != nil { assert.NoError(t, err)
t.Error(err) assert.Equal(t, ":443", addrs[0].String())
return
}
if addrs[0].String() != ":443" {
t.Errorf("expected %v got %v", ":443", addrs[0].String())
}
}) })
t.Run("specific address", func(t *testing.T) { t.Run("specific address", func(t *testing.T) {
addrs, err := resolveBindAddr([]string{"192.168.1.123:443"}) addrs, err := resolveBindAddr([]string{"192.168.1.123:443"})
if err != nil { assert.NoError(t, err)
t.Error(err) assert.Equal(t, "192.168.1.123:443", addrs[0].String())
return
}
if addrs[0].String() != "192.168.1.123:443" {
t.Errorf("expected %v got %v", "192.168.1.123:443", addrs[0].String())
}
}) })
t.Run("ipv6", func(t *testing.T) { t.Run("ipv6", func(t *testing.T) {
addrs, err := resolveBindAddr([]string{"[::]:443"}) addrs, err := resolveBindAddr([]string{"[::]:443"})
if err != nil { assert.NoError(t, err)
t.Error(err) assert.Equal(t, "[::]:443", addrs[0].String())
return
}
if addrs[0].String() != "[::]:443" {
t.Errorf("expected %v got %v", "[::]:443", addrs[0].String())
}
}) })
t.Run("mixed", func(t *testing.T) { t.Run("mixed", func(t *testing.T) {
addrs, err := resolveBindAddr([]string{":80", "[::]:443"}) addrs, err := resolveBindAddr([]string{":80", "[::]:443"})
if err != nil { assert.NoError(t, err)
t.Error(err) assert.Equal(t, ":80", addrs[0].String())
return assert.Equal(t, "[::]:443", addrs[1].String())
}
if addrs[0].String() != ":80" {
t.Errorf("expected %v got %v", ":80", addrs[0].String())
}
if addrs[1].String() != "[::]:443" {
t.Errorf("expected %v got %v", "[::]:443", addrs[1].String())
}
}) })
} }

View File

@ -1,8 +1,8 @@
package client package client
import ( import (
"bytes"
"encoding/hex" "encoding/hex"
"github.com/stretchr/testify/assert"
"testing" "testing"
) )
@ -33,11 +33,6 @@ func TestMakeServerName(t *testing.T) {
} }
for _, p := range pairs { for _, p := range pairs {
if !bytes.Equal(makeServerName(p.serverName), p.target) { assert.Equal(t, p.target, makeServerName(p.serverName))
t.Error(
"for", p.serverName,
"expecting", p.target,
"got", makeServerName(p.serverName))
}
} }
} }

View File

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"github.com/cbeuw/Cloak/internal/common" "github.com/cbeuw/Cloak/internal/common"
"github.com/cbeuw/Cloak/internal/multiplex" "github.com/cbeuw/Cloak/internal/multiplex"
"github.com/stretchr/testify/assert"
"testing" "testing"
"time" "time"
) )
@ -64,12 +65,8 @@ func TestMakeAuthenticationPayload(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
func() { func() {
payload, sharedSecret := makeAuthenticationPayload(tc.authInfo) payload, sharedSecret := makeAuthenticationPayload(tc.authInfo)
if payload != tc.expPayload { assert.Equal(t, tc.expPayload, payload, "payload doesn't match")
t.Errorf("payload doesn't match:\nexp %v\ngot %v", tc.expPayload, payload) assert.Equal(t, tc.expSecret, sharedSecret, "shared secret doesn't match")
}
if sharedSecret != tc.expSecret {
t.Errorf("secret doesn't match:\nexp %x\ngot %x", tc.expPayload, payload)
}
}() }()
} }
} }

View File

@ -30,7 +30,7 @@ type RawConfig struct {
LocalPort string // jsonOptional LocalPort string // jsonOptional
RemoteHost string // jsonOptional RemoteHost string // jsonOptional
RemotePort string // jsonOptional RemotePort string // jsonOptional
AlternativeNames []string // jsonOptional
// defaults set in ProcessRawConfig // defaults set in ProcessRawConfig
UDP bool // nullable UDP bool // nullable
BrowserSig string // nullable BrowserSig string // nullable
@ -51,6 +51,7 @@ type RemoteConnConfig struct {
type LocalConnConfig struct { type LocalConnConfig struct {
LocalAddr string LocalAddr string
Timeout time.Duration Timeout time.Duration
MockDomainList []string
} }
type AuthInfo struct { type AuthInfo struct {
@ -94,6 +95,20 @@ func ssvToJson(ssv string) (ret []byte) {
} }
key := sp[0] key := sp[0]
value := sp[1] value := sp[1]
if strings.HasPrefix(key, "AlternativeNames") {
switch strings.Contains(value, ",") {
case true:
domains := strings.Split(value, ",")
for index, domain := range domains {
domains[index] = `"` + domain + `"`
}
value = strings.Join(domains, ",")
ret = append(ret, []byte(`"`+key+`":[`+value+`],`)...)
case false:
ret = append(ret, []byte(`"`+key+`":["`+value+`"],`)...)
}
continue
}
// JSON doesn't like quotation marks around int and bool // JSON doesn't like quotation marks around int and bool
// This is extremely ugly but it's still better than writing a tokeniser // This is extremely ugly but it's still better than writing a tokeniser
if elem(key, unquoted) { if elem(key, unquoted) {
@ -139,6 +154,8 @@ func (raw *RawConfig) ProcessRawConfig(worldState common.WorldState) (local Loca
return nullErr("ServerName") return nullErr("ServerName")
} }
auth.MockDomain = raw.ServerName auth.MockDomain = raw.ServerName
local.MockDomainList = raw.AlternativeNames
local.MockDomainList = append(local.MockDomainList, auth.MockDomain)
if raw.ProxyMethod == "" { if raw.ProxyMethod == "" {
return nullErr("ServerName") return nullErr("ServerName")
} }

View File

@ -2,6 +2,7 @@ package common
import ( import (
"encoding/binary" "encoding/binary"
"errors"
"io" "io"
"net" "net"
"sync" "sync"
@ -94,6 +95,9 @@ func (tls *TLSConn) Read(buffer []byte) (n int, err error) {
func (tls *TLSConn) Write(in []byte) (n int, err error) { func (tls *TLSConn) Write(in []byte) (n int, err error) {
msgLen := len(in) msgLen := len(in)
if msgLen > 1<<14+256 { // https://tools.ietf.org/html/rfc8446#section-5.2
return 0, errors.New("message is too long")
}
writeBuf := tls.writeBufPool.Get().(*[]byte) writeBuf := tls.writeBufPool.Get().(*[]byte)
*writeBuf = append(*writeBuf, byte(msgLen>>8), byte(msgLen&0xFF)) *writeBuf = append(*writeBuf, byte(msgLen>>8), byte(msgLen&0xFF))
*writeBuf = append(*writeBuf, in...) *writeBuf = append(*writeBuf, in...)

View File

@ -1,7 +1,7 @@
package multiplex package multiplex
import ( import (
"bytes" "github.com/stretchr/testify/assert"
"testing" "testing"
"time" "time"
) )
@ -11,13 +11,7 @@ func TestDatagramBuffer_RW(t *testing.T) {
t.Run("simple write", func(t *testing.T) { t.Run("simple write", func(t *testing.T) {
pipe := NewDatagramBufferedPipe() pipe := NewDatagramBufferedPipe()
_, err := pipe.Write(&Frame{Payload: b}) _, err := pipe.Write(&Frame{Payload: b})
if err != nil { assert.NoError(t, err)
t.Error(
"expecting", "nil error",
"got", err,
)
return
}
}) })
t.Run("simple read", func(t *testing.T) { t.Run("simple read", func(t *testing.T) {
@ -25,50 +19,18 @@ func TestDatagramBuffer_RW(t *testing.T) {
_, _ = pipe.Write(&Frame{Payload: b}) _, _ = pipe.Write(&Frame{Payload: b})
b2 := make([]byte, len(b)) b2 := make([]byte, len(b))
n, err := pipe.Read(b2) n, err := pipe.Read(b2)
if n != len(b) { assert.NoError(t, err)
t.Error( assert.Equal(t, len(b), n)
"For", "number of bytes read", assert.Equal(t, b, b2)
"expecting", len(b), assert.Equal(t, 0, pipe.buf.Len(), "buf len is not 0 after finished reading")
"got", n,
)
return
}
if err != nil {
t.Error(
"expecting", "nil error",
"got", err,
)
return
}
if !bytes.Equal(b, b2) {
t.Error(
"expecting", b,
"got", b2,
)
}
if pipe.buf.Len() != 0 {
t.Error("buf len is not 0 after finished reading")
return
}
}) })
t.Run("writing closing frame", func(t *testing.T) { t.Run("writing closing frame", func(t *testing.T) {
pipe := NewDatagramBufferedPipe() pipe := NewDatagramBufferedPipe()
toBeClosed, err := pipe.Write(&Frame{Closing: closingStream}) toBeClosed, err := pipe.Write(&Frame{Closing: closingStream})
if !toBeClosed { assert.NoError(t, err)
t.Error("should be to be closed") assert.True(t, toBeClosed, "should be to be closed")
} assert.True(t, pipe.closed, "pipe should be closed")
if err != nil {
t.Error(
"expecting", "nil error",
"got", err,
)
return
}
if !pipe.closed {
t.Error("expecting closed pipe, not closed")
}
}) })
} }
@ -81,30 +43,9 @@ func TestDatagramBuffer_BlockingRead(t *testing.T) {
}() }()
b2 := make([]byte, len(b)) b2 := make([]byte, len(b))
n, err := pipe.Read(b2) n, err := pipe.Read(b2)
if n != len(b) { assert.NoError(t, err)
t.Error( assert.Equal(t, len(b), n, "number of bytes read after block is wrong")
"For", "number of bytes read after block", assert.Equal(t, b, b2)
"expecting", len(b),
"got", n,
)
return
}
if err != nil {
t.Error(
"For", "blocked read",
"expecting", "nil error",
"got", err,
)
return
}
if !bytes.Equal(b, b2) {
t.Error(
"For", "blocked read",
"expecting", b,
"got", b2,
)
return
}
} }
func TestDatagramBuffer_CloseThenRead(t *testing.T) { func TestDatagramBuffer_CloseThenRead(t *testing.T) {
@ -114,27 +55,7 @@ func TestDatagramBuffer_CloseThenRead(t *testing.T) {
b2 := make([]byte, len(b)) b2 := make([]byte, len(b))
pipe.Close() pipe.Close()
n, err := pipe.Read(b2) n, err := pipe.Read(b2)
if n != len(b) { assert.NoError(t, err)
t.Error( assert.Equal(t, len(b), n, "number of bytes read after block is wrong")
"For", "number of bytes read", assert.Equal(t, b, b2)
"expecting", len(b),
"got", n,
)
}
if err != nil {
t.Error(
"For", "simple read",
"expecting", "nil error",
"got", err,
)
return
}
if !bytes.Equal(b, b2) {
t.Error(
"For", "simple read",
"expecting", b,
"got", b2,
)
return
}
} }

View File

@ -108,9 +108,7 @@ func TestMultiplex(t *testing.T) {
streams := make([]net.Conn, numStreams) streams := make([]net.Conn, numStreams)
for i := 0; i < numStreams; i++ { for i := 0; i < numStreams; i++ {
stream, err := clientSession.OpenStream() stream, err := clientSession.OpenStream()
if err != nil { assert.NoError(t, err)
t.Fatalf("failed to open stream: %v", err)
}
streams[i] = stream streams[i] = stream
} }
@ -123,18 +121,11 @@ func TestMultiplex(t *testing.T) {
// close one stream // close one stream
closing, streams := streams[0], streams[1:] closing, streams := streams[0], streams[1:]
err := closing.Close() err := closing.Close()
if err != nil { assert.NoError(t, err, "couldn't close a stream")
t.Errorf("couldn't close a stream")
}
_, err = closing.Write([]byte{0}) _, err = closing.Write([]byte{0})
if err != ErrBrokenStream { assert.Equal(t, ErrBrokenStream, err)
t.Errorf("expecting error %v, got %v", ErrBrokenStream, err)
}
_, err = closing.Read(make([]byte, 1)) _, err = closing.Read(make([]byte, 1))
if err != ErrBrokenStream { assert.Equal(t, ErrBrokenStream, err)
t.Errorf("expecting error %v, got %v", ErrBrokenStream, err)
}
} }
func TestMux_StreamClosing(t *testing.T) { func TestMux_StreamClosing(t *testing.T) {
@ -146,20 +137,13 @@ func TestMux_StreamClosing(t *testing.T) {
recvBuf := make([]byte, 128) recvBuf := make([]byte, 128)
toBeClosed, _ := clientSession.OpenStream() toBeClosed, _ := clientSession.OpenStream()
_, err := toBeClosed.Write(testData) // should be echoed back _, err := toBeClosed.Write(testData) // should be echoed back
if err != nil { assert.NoError(t, err, "couldn't write to a stream")
t.Errorf("can't write to stream: %v", err)
}
_, err = io.ReadFull(toBeClosed, recvBuf[:1]) _, err = io.ReadFull(toBeClosed, recvBuf[:1])
if err != nil { assert.NoError(t, err, "can't read anything before stream closed")
t.Errorf("can't read anything before stream closed: %v", err)
}
_ = toBeClosed.Close() _ = toBeClosed.Close()
_, err = io.ReadFull(toBeClosed, recvBuf[1:]) _, err = io.ReadFull(toBeClosed, recvBuf[1:])
if err != nil { assert.NoError(t, err, "can't read residual data on stream")
t.Errorf("can't read residual data on stream: %v", err) assert.Equal(t, testData, recvBuf, "incorrect data read back")
}
if !bytes.Equal(testData, recvBuf) {
t.Errorf("incorrect data read back")
}
} }

View File

@ -138,7 +138,7 @@ func BenchmarkObfs(b *testing.B) {
testPayload, testPayload,
} }
obfsBuf := make([]byte, defaultSendRecvBufSize) obfsBuf := make([]byte, len(testPayload)*2)
var key [32]byte var key [32]byte
rand.Read(key[:]) rand.Read(key[:])
@ -211,7 +211,7 @@ func BenchmarkDeobfs(b *testing.B) {
testPayload, testPayload,
} }
obfsBuf := make([]byte, defaultSendRecvBufSize) obfsBuf := make([]byte, len(testPayload)*2)
var key [32]byte var key [32]byte
rand.Read(key[:]) rand.Read(key[:])

View File

@ -25,4 +25,4 @@ type recvBuffer interface {
// size we want the amount of unread data in buffer to grow before recvBuffer.Write blocks. // size we want the amount of unread data in buffer to grow before recvBuffer.Write blocks.
// If the buffer grows larger than what the system's memory can offer at the time of recvBuffer.Write, // If the buffer grows larger than what the system's memory can offer at the time of recvBuffer.Write,
// a panic will happen. // a panic will happen.
const recvBufferSizeLimit = defaultSendRecvBufSize << 12 const recvBufferSizeLimit = 1<<31 - 1

View File

@ -14,9 +14,8 @@ import (
const ( const (
acceptBacklog = 1024 acceptBacklog = 1024
// TODO: will this be a signature?
defaultSendRecvBufSize = 20480
defaultInactivityTimeout = 30 * time.Second defaultInactivityTimeout = 30 * time.Second
defaultMaxOnWireSize = 1<<14 + 256 // https://tools.ietf.org/html/rfc8446#section-5.2
) )
var ErrBrokenSession = errors.New("broken session") var ErrBrokenSession = errors.New("broken session")
@ -24,8 +23,6 @@ var errRepeatSessionClosing = errors.New("trying to close a closed session")
var errRepeatStreamClosing = errors.New("trying to close a closed stream") var errRepeatStreamClosing = errors.New("trying to close a closed stream")
var errNoMultiplex = errors.New("a singleplexing session can have only one stream") var errNoMultiplex = errors.New("a singleplexing session can have only one stream")
type switchboardStrategy int
type SessionConfig struct { type SessionConfig struct {
Obfuscator Obfuscator
@ -40,12 +37,6 @@ type SessionConfig struct {
// maximum size of an obfuscated frame, including headers and overhead // maximum size of an obfuscated frame, including headers and overhead
MsgOnWireSizeLimit int MsgOnWireSizeLimit int
// StreamSendBufferSize sets the buffer size used to send data from a Stream (Stream.obfsBuf)
StreamSendBufferSize int
// ConnReceiveBufferSize sets the buffer size used to receive data from an underlying Conn (allocated in
// switchboard.deplex)
ConnReceiveBufferSize int
// InactivityTimeout sets the duration a Session waits while it has no active streams before it closes itself // InactivityTimeout sets the duration a Session waits while it has no active streams before it closes itself
InactivityTimeout time.Duration InactivityTimeout time.Duration
} }
@ -82,11 +73,17 @@ type Session struct {
closed uint32 closed uint32
terminalMsg atomic.Value terminalMsgSetter sync.Once
terminalMsg string
// the max size passed to Write calls before it splits it into multiple frames // the max size passed to Write calls before it splits it into multiple frames
// i.e. the max size a piece of data can fit into a Frame.Payload // i.e. the max size a piece of data can fit into a Frame.Payload
maxStreamUnitWrite int maxStreamUnitWrite int
// streamSendBufferSize sets the buffer size used to send data from a Stream (Stream.obfsBuf)
streamSendBufferSize int
// connReceiveBufferSize sets the buffer size used to receive data from an underlying Conn (allocated in
// switchboard.deplex)
connReceiveBufferSize int
} }
func MakeSession(id uint32, config SessionConfig) *Session { func MakeSession(id uint32, config SessionConfig) *Session {
@ -103,23 +100,19 @@ func MakeSession(id uint32, config SessionConfig) *Session {
if config.Valve == nil { if config.Valve == nil {
sesh.Valve = UNLIMITED_VALVE sesh.Valve = UNLIMITED_VALVE
} }
if config.StreamSendBufferSize <= 0 {
sesh.StreamSendBufferSize = defaultSendRecvBufSize
}
if config.ConnReceiveBufferSize <= 0 {
sesh.ConnReceiveBufferSize = defaultSendRecvBufSize
}
if config.MsgOnWireSizeLimit <= 0 { if config.MsgOnWireSizeLimit <= 0 {
sesh.MsgOnWireSizeLimit = defaultSendRecvBufSize - 1024 sesh.MsgOnWireSizeLimit = defaultMaxOnWireSize
} }
if config.InactivityTimeout == 0 { if config.InactivityTimeout == 0 {
sesh.InactivityTimeout = defaultInactivityTimeout sesh.InactivityTimeout = defaultInactivityTimeout
} }
// todo: validation. this must be smaller than StreamSendBufferSize
sesh.maxStreamUnitWrite = sesh.MsgOnWireSizeLimit - frameHeaderLength - sesh.Obfuscator.maxOverhead sesh.maxStreamUnitWrite = sesh.MsgOnWireSizeLimit - frameHeaderLength - sesh.maxOverhead
sesh.streamSendBufferSize = sesh.MsgOnWireSizeLimit
sesh.connReceiveBufferSize = 20480 // for backwards compatibility
sesh.streamObfsBufPool = sync.Pool{New: func() interface{} { sesh.streamObfsBufPool = sync.Pool{New: func() interface{} {
b := make([]byte, sesh.StreamSendBufferSize) b := make([]byte, sesh.streamSendBufferSize)
return &b return &b
}} }}
@ -187,7 +180,7 @@ func (sesh *Session) closeStream(s *Stream, active bool) error {
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) { if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
return fmt.Errorf("closing stream %v: %w", s.id, errRepeatStreamClosing) return fmt.Errorf("closing stream %v: %w", s.id, errRepeatStreamClosing)
} }
_ = s.getRecvBuf().Close() // recvBuf.Close should not return error _ = s.recvBuf.Close() // recvBuf.Close should not return error
if active { if active {
tmpBuf := sesh.streamObfsBufPool.Get().(*[]byte) tmpBuf := sesh.streamObfsBufPool.Get().(*[]byte)
@ -271,16 +264,13 @@ func (sesh *Session) recvDataFromRemote(data []byte) error {
} }
func (sesh *Session) SetTerminalMsg(msg string) { func (sesh *Session) SetTerminalMsg(msg string) {
sesh.terminalMsg.Store(msg) sesh.terminalMsgSetter.Do(func() {
sesh.terminalMsg = msg
})
} }
func (sesh *Session) TerminalMsg() string { func (sesh *Session) TerminalMsg() string {
msg := sesh.terminalMsg.Load() return sesh.terminalMsg
if msg != nil {
return msg.(string)
} else {
return ""
}
} }
func (sesh *Session) closeSession() error { func (sesh *Session) closeSession() error {
@ -292,14 +282,12 @@ func (sesh *Session) closeSession() error {
sesh.streamsM.Lock() sesh.streamsM.Lock()
close(sesh.acceptCh) close(sesh.acceptCh)
for id, stream := range sesh.streams { for id, stream := range sesh.streams {
if stream == nil { if stream != nil && atomic.CompareAndSwapUint32(&stream.closed, 0, 1) {
continue _ = stream.recvBuf.Close() // will not block
}
atomic.StoreUint32(&stream.closed, 1)
_ = stream.getRecvBuf().Close() // will not block
delete(sesh.streams, id) delete(sesh.streams, id)
sesh.streamCountDecr() sesh.streamCountDecr()
} }
}
sesh.streamsM.Unlock() sesh.streamsM.Unlock()
return nil return nil
} }
@ -339,7 +327,7 @@ func (sesh *Session) Close() error {
if err != nil { if err != nil {
return err return err
} }
_, err = sesh.sb.send((*buf)[:i], new(uint32)) _, err = sesh.sb.send((*buf)[:i], new(net.Conn))
if err != nil { if err != nil {
return err return err
} }

View File

@ -534,7 +534,7 @@ func TestSession_timeoutAfter(t *testing.T) {
func BenchmarkRecvDataFromRemote(b *testing.B) { func BenchmarkRecvDataFromRemote(b *testing.B) {
testPayload := make([]byte, testPayloadLen) testPayload := make([]byte, testPayloadLen)
rand.Read(testPayload) rand.Read(testPayload)
f := &Frame{ f := Frame{
1, 1,
0, 0,
0, 0,
@ -544,12 +544,13 @@ func BenchmarkRecvDataFromRemote(b *testing.B) {
var sessionKey [32]byte var sessionKey [32]byte
rand.Read(sessionKey[:]) rand.Read(sessionKey[:])
const maxIter = 100_000 // run with -benchtime 100000x to avoid index out of bounds panic const maxIter = 500_000 // run with -benchtime 500000x to avoid index out of bounds panic
for name, ep := range encryptionMethods { for name, ep := range encryptionMethods {
ep := ep ep := ep
b.Run(name, func(b *testing.B) { b.Run(name, func(b *testing.B) {
for seshType, seshConfig := range seshConfigs { for seshType, seshConfig := range seshConfigs {
b.Run(seshType, func(b *testing.B) { b.Run(seshType, func(b *testing.B) {
f := f
seshConfig.Obfuscator, _ = MakeObfuscator(ep, sessionKey) seshConfig.Obfuscator, _ = MakeObfuscator(ep, sessionKey)
sesh := MakeSession(0, seshConfig) sesh := MakeSession(0, seshConfig)
@ -561,7 +562,7 @@ func BenchmarkRecvDataFromRemote(b *testing.B) {
binaryFrames := [maxIter][]byte{} binaryFrames := [maxIter][]byte{}
for i := 0; i < maxIter; i++ { for i := 0; i < maxIter; i++ {
obfsBuf := make([]byte, obfsBufLen) obfsBuf := make([]byte, obfsBufLen)
n, _ := sesh.obfuscate(f, obfsBuf, 0) n, _ := sesh.obfuscate(&f, obfsBuf, 0)
binaryFrames[i] = obfsBuf[:n] binaryFrames[i] = obfsBuf[:n]
f.Seq++ f.Seq++
} }

View File

@ -23,9 +23,8 @@ type Stream struct {
session *Session session *Session
allocIdempot sync.Once
// a buffer (implemented as an asynchronous buffered pipe) to put data we've received from recvFrame but hasn't // a buffer (implemented as an asynchronous buffered pipe) to put data we've received from recvFrame but hasn't
// been read by the consumer through Read or WriteTo. Lazily allocated // been read by the consumer through Read or WriteTo.
recvBuf recvBuffer recvBuf recvBuffer
writingM sync.Mutex writingM sync.Mutex
@ -40,7 +39,7 @@ type Stream struct {
// recvBuffer (implemented by streamBuffer under ordered mode) will not receive out-of-order packets // recvBuffer (implemented by streamBuffer under ordered mode) will not receive out-of-order packets
// so it won't have to use its priority queue to sort it. // so it won't have to use its priority queue to sort it.
// This is not used in unordered connection mode // This is not used in unordered connection mode
assignedConnId uint32 assignedConn net.Conn
readFromTimeout time.Duration readFromTimeout time.Duration
} }
@ -56,25 +55,20 @@ func makeStream(sesh *Session, id uint32) *Stream {
}, },
} }
if sesh.Unordered {
stream.recvBuf = NewDatagramBufferedPipe()
} else {
stream.recvBuf = NewStreamBuffer()
}
return stream return stream
} }
func (s *Stream) isClosed() bool { return atomic.LoadUint32(&s.closed) == 1 } func (s *Stream) isClosed() bool { return atomic.LoadUint32(&s.closed) == 1 }
func (s *Stream) getRecvBuf() recvBuffer {
s.allocIdempot.Do(func() {
if s.session.Unordered {
s.recvBuf = NewDatagramBufferedPipe()
} else {
s.recvBuf = NewStreamBuffer()
}
})
return s.recvBuf
}
// receive a readily deobfuscated Frame so its payload can later be Read // receive a readily deobfuscated Frame so its payload can later be Read
func (s *Stream) recvFrame(frame *Frame) error { func (s *Stream) recvFrame(frame *Frame) error {
toBeClosed, err := s.getRecvBuf().Write(frame) toBeClosed, err := s.recvBuf.Write(frame)
if toBeClosed { if toBeClosed {
err = s.passiveClose() err = s.passiveClose()
if errors.Is(err, errRepeatStreamClosing) { if errors.Is(err, errRepeatStreamClosing) {
@ -93,7 +87,7 @@ func (s *Stream) Read(buf []byte) (n int, err error) {
return 0, nil return 0, nil
} }
n, err = s.getRecvBuf().Read(buf) n, err = s.recvBuf.Read(buf)
log.Tracef("%v read from stream %v with err %v", n, s.id, err) log.Tracef("%v read from stream %v with err %v", n, s.id, err)
if err == io.EOF { if err == io.EOF {
return n, ErrBrokenStream return n, ErrBrokenStream
@ -104,7 +98,7 @@ func (s *Stream) Read(buf []byte) (n int, err error) {
// WriteTo continuously write data Stream has received into the writer w. // WriteTo continuously write data Stream has received into the writer w.
func (s *Stream) WriteTo(w io.Writer) (int64, error) { func (s *Stream) WriteTo(w io.Writer) (int64, error) {
// will keep writing until the underlying buffer is closed // will keep writing until the underlying buffer is closed
n, err := s.getRecvBuf().WriteTo(w) n, err := s.recvBuf.WriteTo(w)
log.Tracef("%v read from stream %v with err %v", n, s.id, err) log.Tracef("%v read from stream %v with err %v", n, s.id, err)
if err == io.EOF { if err == io.EOF {
return n, ErrBrokenStream return n, ErrBrokenStream
@ -119,7 +113,7 @@ func (s *Stream) obfuscateAndSend(buf []byte, payloadOffsetInBuf int) error {
return err return err
} }
_, err = s.session.sb.send(buf[:cipherTextLen], &s.assignedConnId) _, err = s.session.sb.send(buf[:cipherTextLen], &s.assignedConn)
if err != nil { if err != nil {
if err == errBrokenSwitchboard { if err == errBrokenSwitchboard {
s.session.SetTerminalMsg(err.Error()) s.session.SetTerminalMsg(err.Error())
@ -215,8 +209,8 @@ func (s *Stream) Close() error {
func (s *Stream) LocalAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[0] } func (s *Stream) LocalAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[0] }
func (s *Stream) RemoteAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[1] } func (s *Stream) RemoteAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[1] }
func (s *Stream) SetWriteToTimeout(d time.Duration) { s.getRecvBuf().SetWriteToTimeout(d) } func (s *Stream) SetWriteToTimeout(d time.Duration) { s.recvBuf.SetWriteToTimeout(d) }
func (s *Stream) SetReadDeadline(t time.Time) error { s.getRecvBuf().SetReadDeadline(t); return nil } func (s *Stream) SetReadDeadline(t time.Time) error { s.recvBuf.SetReadDeadline(t); return nil }
func (s *Stream) SetReadFromTimeout(d time.Duration) { s.readFromTimeout = d } func (s *Stream) SetReadFromTimeout(d time.Duration) { s.readFromTimeout = d }
var errNotImplemented = errors.New("Not implemented") var errNotImplemented = errors.New("Not implemented")

View File

@ -1,7 +1,7 @@
package multiplex package multiplex
import ( import (
"bytes" "github.com/stretchr/testify/assert"
"math/rand" "math/rand"
"testing" "testing"
"time" "time"
@ -13,49 +13,15 @@ func TestPipeRW(t *testing.T) {
pipe := NewStreamBufferedPipe() pipe := NewStreamBufferedPipe()
b := []byte{0x01, 0x02, 0x03} b := []byte{0x01, 0x02, 0x03}
n, err := pipe.Write(b) n, err := pipe.Write(b)
if n != len(b) { assert.NoError(t, err, "simple write")
t.Error( assert.Equal(t, len(b), n, "number of bytes written")
"For", "number of bytes written",
"expecting", len(b),
"got", n,
)
return
}
if err != nil {
t.Error(
"For", "simple write",
"expecting", "nil error",
"got", err,
)
return
}
b2 := make([]byte, len(b)) b2 := make([]byte, len(b))
n, err = pipe.Read(b2) n, err = pipe.Read(b2)
if n != len(b) { assert.NoError(t, err, "simple read")
t.Error( assert.Equal(t, len(b), n, "number of bytes read")
"For", "number of bytes read",
"expecting", len(b),
"got", n,
)
return
}
if err != nil {
t.Error(
"For", "simple read",
"expecting", "nil error",
"got", err,
)
return
}
if !bytes.Equal(b, b2) {
t.Error(
"For", "simple read",
"expecting", b,
"got", b2,
)
}
assert.Equal(t, b, b2)
} }
func TestReadBlock(t *testing.T) { func TestReadBlock(t *testing.T) {
@ -67,30 +33,10 @@ func TestReadBlock(t *testing.T) {
}() }()
b2 := make([]byte, len(b)) b2 := make([]byte, len(b))
n, err := pipe.Read(b2) n, err := pipe.Read(b2)
if n != len(b) { assert.NoError(t, err, "blocked read")
t.Error( assert.Equal(t, len(b), n, "number of bytes read after block")
"For", "number of bytes read after block",
"expecting", len(b), assert.Equal(t, b, b2)
"got", n,
)
return
}
if err != nil {
t.Error(
"For", "blocked read",
"expecting", "nil error",
"got", err,
)
return
}
if !bytes.Equal(b, b2) {
t.Error(
"For", "blocked read",
"expecting", b,
"got", b2,
)
return
}
} }
func TestPartialRead(t *testing.T) { func TestPartialRead(t *testing.T) {
@ -99,54 +45,17 @@ func TestPartialRead(t *testing.T) {
pipe.Write(b) pipe.Write(b)
b1 := make([]byte, 1) b1 := make([]byte, 1)
n, err := pipe.Read(b1) n, err := pipe.Read(b1)
if n != len(b1) { assert.NoError(t, err, "partial read of 1")
t.Error( assert.Equal(t, len(b1), n, "number of bytes in partial read of 1")
"For", "number of bytes in partial read of 1",
"expecting", len(b1), assert.Equal(t, b[0], b1[0])
"got", n,
)
return
}
if err != nil {
t.Error(
"For", "partial read of 1",
"expecting", "nil error",
"got", err,
)
return
}
if b1[0] != b[0] {
t.Error(
"For", "partial read of 1",
"expecting", b[0],
"got", b1[0],
)
}
b2 := make([]byte, 2) b2 := make([]byte, 2)
n, err = pipe.Read(b2) n, err = pipe.Read(b2)
if n != len(b2) { assert.NoError(t, err, "partial read of 2")
t.Error( assert.Equal(t, len(b2), n, "number of bytes in partial read of 2")
"For", "number of bytes in partial read of 2",
"expecting", len(b2), assert.Equal(t, b[1:], b2)
"got", n,
)
}
if err != nil {
t.Error(
"For", "partial read of 2",
"expecting", "nil error",
"got", err,
)
return
}
if !bytes.Equal(b[1:], b2) {
t.Error(
"For", "partial read of 2",
"expecting", b[1:],
"got", b2,
)
return
}
} }
func TestReadAfterClose(t *testing.T) { func TestReadAfterClose(t *testing.T) {
@ -156,29 +65,10 @@ func TestReadAfterClose(t *testing.T) {
b2 := make([]byte, len(b)) b2 := make([]byte, len(b))
pipe.Close() pipe.Close()
n, err := pipe.Read(b2) n, err := pipe.Read(b2)
if n != len(b) { assert.NoError(t, err, "simple read")
t.Error( assert.Equal(t, len(b), n, "number of bytes read")
"For", "number of bytes read",
"expecting", len(b), assert.Equal(t, b, b2)
"got", n,
)
}
if err != nil {
t.Error(
"For", "simple read",
"expecting", "nil error",
"got", err,
)
return
}
if !bytes.Equal(b, b2) {
t.Error(
"For", "simple read",
"expecting", b,
"got", b2,
)
return
}
} }
func BenchmarkBufferedPipe_RW(b *testing.B) { func BenchmarkBufferedPipe_RW(b *testing.B) {

View File

@ -10,9 +10,11 @@ import (
"time" "time"
) )
type switchboardStrategy int
const ( const (
FIXED_CONN_MAPPING switchboardStrategy = iota fixedConnMapping switchboardStrategy = iota
UNIFORM_SPREAD uniformSpread
) )
// switchboard represents the connection pool. It is responsible for managing // switchboard represents the connection pool. It is responsible for managing
@ -28,10 +30,8 @@ type switchboard struct {
valve Valve valve Valve
strategy switchboardStrategy strategy switchboardStrategy
// map of connId to net.Conn
conns sync.Map conns sync.Map
numConns uint32 connsCount uint32
nextConnId uint32
randPool sync.Pool randPool sync.Pool
broken uint32 broken uint32
@ -41,15 +41,14 @@ func makeSwitchboard(sesh *Session) *switchboard {
var strategy switchboardStrategy var strategy switchboardStrategy
if sesh.Unordered { if sesh.Unordered {
log.Debug("Connection is unordered") log.Debug("Connection is unordered")
strategy = UNIFORM_SPREAD strategy = uniformSpread
} else { } else {
strategy = FIXED_CONN_MAPPING strategy = fixedConnMapping
} }
sb := &switchboard{ sb := &switchboard{
session: sesh, session: sesh,
strategy: strategy, strategy: strategy,
valve: sesh.Valve, valve: sesh.Valve,
nextConnId: 1,
randPool: sync.Pool{New: func() interface{} { randPool: sync.Pool{New: func() interface{} {
return rand.New(rand.NewSource(int64(time.Now().Nanosecond()))) return rand.New(rand.NewSource(int64(time.Now().Nanosecond())))
}}, }},
@ -59,88 +58,85 @@ func makeSwitchboard(sesh *Session) *switchboard {
var errBrokenSwitchboard = errors.New("the switchboard is broken") var errBrokenSwitchboard = errors.New("the switchboard is broken")
func (sb *switchboard) connsCount() int {
return int(atomic.LoadUint32(&sb.numConns))
}
func (sb *switchboard) addConn(conn net.Conn) { func (sb *switchboard) addConn(conn net.Conn) {
connId := atomic.AddUint32(&sb.nextConnId, 1) - 1 atomic.AddUint32(&sb.connsCount, 1)
atomic.AddUint32(&sb.numConns, 1) sb.conns.Store(conn, conn)
sb.conns.Store(connId, conn) go sb.deplex(conn)
go sb.deplex(connId, conn)
} }
// a pointer to connId is passed here so that the switchboard can reassign it if that connId isn't usable // a pointer to assignedConn is passed here so that the switchboard can reassign it if that conn isn't usable
func (sb *switchboard) send(data []byte, connId *uint32) (n int, err error) { func (sb *switchboard) send(data []byte, assignedConn *net.Conn) (n int, err error) {
sb.valve.txWait(len(data)) sb.valve.txWait(len(data))
if atomic.LoadUint32(&sb.broken) == 1 || sb.connsCount() == 0 { if atomic.LoadUint32(&sb.broken) == 1 {
return 0, errBrokenSwitchboard return 0, errBrokenSwitchboard
} }
var conn net.Conn var conn net.Conn
switch sb.strategy { switch sb.strategy {
case UNIFORM_SPREAD: case uniformSpread:
_, conn, err = sb.pickRandConn() conn, err = sb.pickRandConn()
if err != nil { if err != nil {
return 0, errBrokenSwitchboard return 0, errBrokenSwitchboard
} }
case FIXED_CONN_MAPPING: n, err = conn.Write(data)
connI, ok := sb.conns.Load(*connId)
if ok {
conn = connI.(net.Conn)
} else {
var newConnId uint32
newConnId, conn, err = sb.pickRandConn()
if err != nil { if err != nil {
return 0, errBrokenSwitchboard sb.session.SetTerminalMsg("failed to send to remote " + err.Error())
sb.session.passiveClose()
return n, err
} }
*connId = newConnId case fixedConnMapping:
conn = *assignedConn
if conn == nil {
conn, err = sb.pickRandConn()
if err != nil {
sb.session.SetTerminalMsg("failed to pick a connection " + err.Error())
sb.session.passiveClose()
return 0, err
}
*assignedConn = conn
}
n, err = conn.Write(data)
if err != nil {
sb.session.SetTerminalMsg("failed to send to remote " + err.Error())
sb.session.passiveClose()
return n, err
} }
default: default:
return 0, errors.New("unsupported traffic distribution strategy") return 0, errors.New("unsupported traffic distribution strategy")
} }
n, err = conn.Write(data)
if err != nil {
sb.conns.Delete(*connId)
sb.session.SetTerminalMsg("failed to write to remote " + err.Error())
sb.session.passiveClose()
return n, err
}
sb.valve.AddTx(int64(n)) sb.valve.AddTx(int64(n))
return n, nil return n, nil
} }
// returns a random connId // returns a random connId
func (sb *switchboard) pickRandConn() (uint32, net.Conn, error) { func (sb *switchboard) pickRandConn() (net.Conn, error) {
connCount := sb.connsCount() if atomic.LoadUint32(&sb.broken) == 1 {
if atomic.LoadUint32(&sb.broken) == 1 || connCount == 0 { return nil, errBrokenSwitchboard
return 0, nil, errBrokenSwitchboard }
connsCount := atomic.LoadUint32(&sb.connsCount)
if connsCount == 0 {
return nil, errBrokenSwitchboard
} }
// there is no guarantee that sb.conns still has the same amount of entries
// between the count loop and the pick loop
// so if the r > len(sb.conns) at the point of range call, the last visited element is picked
var id uint32
var conn net.Conn
randReader := sb.randPool.Get().(*rand.Rand) randReader := sb.randPool.Get().(*rand.Rand)
r := randReader.Intn(connCount)
r := randReader.Intn(int(connsCount))
sb.randPool.Put(randReader) sb.randPool.Put(randReader)
var c int var c int
sb.conns.Range(func(connIdI, connI interface{}) bool { var ret net.Conn
sb.conns.Range(func(_, conn interface{}) bool {
if r == c { if r == c {
id = connIdI.(uint32) ret = conn.(net.Conn)
conn = connI.(net.Conn)
return false return false
} }
c++ c++
return true return true
}) })
// if len(sb.conns) is 0
if conn == nil { return ret, nil
return 0, nil, errBrokenSwitchboard
}
return id, conn, nil
} }
// actively triggered by session.Close() // actively triggered by session.Close()
@ -148,26 +144,24 @@ func (sb *switchboard) closeAll() {
if !atomic.CompareAndSwapUint32(&sb.broken, 0, 1) { if !atomic.CompareAndSwapUint32(&sb.broken, 0, 1) {
return return
} }
sb.conns.Range(func(key, connI interface{}) bool { sb.conns.Range(func(_, conn interface{}) bool {
conn := connI.(net.Conn) conn.(net.Conn).Close()
conn.Close() sb.conns.Delete(conn)
sb.conns.Delete(key) atomic.AddUint32(&sb.connsCount, ^uint32(0))
return true return true
}) })
} }
// deplex function costantly reads from a TCP connection // deplex function costantly reads from a TCP connection
func (sb *switchboard) deplex(connId uint32, conn net.Conn) { func (sb *switchboard) deplex(conn net.Conn) {
defer conn.Close() defer conn.Close()
buf := make([]byte, sb.session.ConnReceiveBufferSize) buf := make([]byte, sb.session.connReceiveBufferSize)
for { for {
n, err := conn.Read(buf) n, err := conn.Read(buf)
sb.valve.rxWait(n) sb.valve.rxWait(n)
sb.valve.AddRx(int64(n)) sb.valve.AddRx(int64(n))
if err != nil { if err != nil {
log.Debugf("a connection for session %v has closed: %v", sb.session.id, err) log.Debugf("a connection for session %v has closed: %v", sb.session.id, err)
sb.conns.Delete(connId)
atomic.AddUint32(&sb.numConns, ^uint32(0))
sb.session.SetTerminalMsg("a connection has dropped unexpectedly") sb.session.SetTerminalMsg("a connection has dropped unexpectedly")
sb.session.passiveClose() sb.session.passiveClose()
return return

View File

@ -5,6 +5,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"math/rand" "math/rand"
"sync" "sync"
"sync/atomic"
"testing" "testing"
"time" "time"
) )
@ -14,14 +15,14 @@ func TestSwitchboard_Send(t *testing.T) {
sesh := MakeSession(0, seshConfig) sesh := MakeSession(0, seshConfig)
hole0 := connutil.Discard() hole0 := connutil.Discard()
sesh.sb.addConn(hole0) sesh.sb.addConn(hole0)
connId, _, err := sesh.sb.pickRandConn() conn, err := sesh.sb.pickRandConn()
if err != nil { if err != nil {
t.Error("failed to get a random conn", err) t.Error("failed to get a random conn", err)
return return
} }
data := make([]byte, 1000) data := make([]byte, 1000)
rand.Read(data) rand.Read(data)
_, err = sesh.sb.send(data, &connId) _, err = sesh.sb.send(data, &conn)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
@ -29,23 +30,23 @@ func TestSwitchboard_Send(t *testing.T) {
hole1 := connutil.Discard() hole1 := connutil.Discard()
sesh.sb.addConn(hole1) sesh.sb.addConn(hole1)
connId, _, err = sesh.sb.pickRandConn() conn, err = sesh.sb.pickRandConn()
if err != nil { if err != nil {
t.Error("failed to get a random conn", err) t.Error("failed to get a random conn", err)
return return
} }
_, err = sesh.sb.send(data, &connId) _, err = sesh.sb.send(data, &conn)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
connId, _, err = sesh.sb.pickRandConn() conn, err = sesh.sb.pickRandConn()
if err != nil { if err != nil {
t.Error("failed to get a random conn", err) t.Error("failed to get a random conn", err)
return return
} }
_, err = sesh.sb.send(data, &connId) _, err = sesh.sb.send(data, &conn)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
@ -71,7 +72,7 @@ func BenchmarkSwitchboard_Send(b *testing.B) {
seshConfig := SessionConfig{} seshConfig := SessionConfig{}
sesh := MakeSession(0, seshConfig) sesh := MakeSession(0, seshConfig)
sesh.sb.addConn(hole) sesh.sb.addConn(hole)
connId, _, err := sesh.sb.pickRandConn() conn, err := sesh.sb.pickRandConn()
if err != nil { if err != nil {
b.Error("failed to get a random conn", err) b.Error("failed to get a random conn", err)
return return
@ -81,7 +82,7 @@ func BenchmarkSwitchboard_Send(b *testing.B) {
b.SetBytes(int64(len(data))) b.SetBytes(int64(len(data)))
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
sesh.sb.send(data, &connId) sesh.sb.send(data, &conn)
} }
} }
@ -92,7 +93,7 @@ func TestSwitchboard_TxCredit(t *testing.T) {
sesh := MakeSession(0, seshConfig) sesh := MakeSession(0, seshConfig)
hole := connutil.Discard() hole := connutil.Discard()
sesh.sb.addConn(hole) sesh.sb.addConn(hole)
connId, _, err := sesh.sb.pickRandConn() conn, err := sesh.sb.pickRandConn()
if err != nil { if err != nil {
t.Error("failed to get a random conn", err) t.Error("failed to get a random conn", err)
return return
@ -100,10 +101,10 @@ func TestSwitchboard_TxCredit(t *testing.T) {
data := make([]byte, 1000) data := make([]byte, 1000)
rand.Read(data) rand.Read(data)
t.Run("FIXED CONN MAPPING", func(t *testing.T) { t.Run("fixed conn mapping", func(t *testing.T) {
*sesh.sb.valve.(*LimitedValve).tx = 0 *sesh.sb.valve.(*LimitedValve).tx = 0
sesh.sb.strategy = FIXED_CONN_MAPPING sesh.sb.strategy = fixedConnMapping
n, err := sesh.sb.send(data[:10], &connId) n, err := sesh.sb.send(data[:10], &conn)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
@ -116,10 +117,10 @@ func TestSwitchboard_TxCredit(t *testing.T) {
t.Error("tx credit didn't increase by 10") t.Error("tx credit didn't increase by 10")
} }
}) })
t.Run("UNIFORM", func(t *testing.T) { t.Run("uniform spread", func(t *testing.T) {
*sesh.sb.valve.(*LimitedValve).tx = 0 *sesh.sb.valve.(*LimitedValve).tx = 0
sesh.sb.strategy = UNIFORM_SPREAD sesh.sb.strategy = uniformSpread
n, err := sesh.sb.send(data[:10], &connId) n, err := sesh.sb.send(data[:10], &conn)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
@ -173,13 +174,13 @@ func TestSwitchboard_ConnsCount(t *testing.T) {
} }
wg.Wait() wg.Wait()
if sesh.sb.connsCount() != 1000 { if atomic.LoadUint32(&sesh.sb.connsCount) != 1000 {
t.Error("connsCount incorrect") t.Error("connsCount incorrect")
} }
sesh.sb.closeAll() sesh.sb.closeAll()
assert.Eventuallyf(t, func() bool { assert.Eventuallyf(t, func() bool {
return sesh.sb.connsCount() == 0 return atomic.LoadUint32(&sesh.sb.connsCount) == 0
}, time.Second, 10*time.Millisecond, "connsCount incorrect: %v", sesh.sb.connsCount()) }, time.Second, 10*time.Millisecond, "connsCount incorrect: %v", atomic.LoadUint32(&sesh.sb.connsCount))
} }

View File

@ -2,7 +2,7 @@ swagger: '2.0'
info: info:
description: | description: |
This is the API of Cloak server This is the API of Cloak server
version: 1.0.0 version: 0.0.2
title: Cloak Server title: Cloak Server
contact: contact:
email: cbeuw.andy@gmail.com email: cbeuw.andy@gmail.com
@ -12,8 +12,6 @@ info:
# host: petstore.swagger.io # host: petstore.swagger.io
# basePath: /v2 # basePath: /v2
tags: tags:
- name: admin
description: Endpoints used by the host administrators
- name: users - name: users
description: Operations related to user controls by admin description: Operations related to user controls by admin
# schemes: # schemes:
@ -22,7 +20,6 @@ paths:
/admin/users: /admin/users:
get: get:
tags: tags:
- admin
- users - users
summary: Show all users summary: Show all users
description: Returns an array of all UserInfo description: Returns an array of all UserInfo
@ -41,7 +38,6 @@ paths:
/admin/users/{UID}: /admin/users/{UID}:
get: get:
tags: tags:
- admin
- users - users
summary: Show userinfo by UID summary: Show userinfo by UID
description: Returns a UserInfo object description: Returns a UserInfo object
@ -68,7 +64,6 @@ paths:
description: internal error description: internal error
post: post:
tags: tags:
- admin
- users - users
summary: Updates the userinfo of the specified user, if the user does not exist, then a new user is created summary: Updates the userinfo of the specified user, if the user does not exist, then a new user is created
operationId: writeUserInfo operationId: writeUserInfo
@ -100,7 +95,6 @@ paths:
description: internal error description: internal error
delete: delete:
tags: tags:
- admin
- users - users
summary: Deletes a user summary: Deletes a user
operationId: deleteUser operationId: deleteUser

View File

@ -46,6 +46,36 @@ func TestWriteUserInfoHlr(t *testing.T) {
assert.Equalf(t, http.StatusCreated, rr.Code, "response body: %v", rr.Body) assert.Equalf(t, http.StatusCreated, rr.Code, "response body: %v", rr.Body)
}) })
t.Run("partial update", func(t *testing.T) {
req, err := http.NewRequest("POST", "/admin/users/"+mockUIDb64, bytes.NewBuffer(marshalled))
assert.NoError(t, err)
rr := httptest.NewRecorder()
router.ServeHTTP(rr, req)
assert.Equal(t, http.StatusCreated, rr.Code)
partialUserInfo := UserInfo{
UID: mockUID,
SessionsCap: JustInt32(10),
}
partialMarshalled, _ := json.Marshal(partialUserInfo)
req, err = http.NewRequest("POST", "/admin/users/"+mockUIDb64, bytes.NewBuffer(partialMarshalled))
assert.NoError(t, err)
router.ServeHTTP(rr, req)
assert.Equal(t, http.StatusCreated, rr.Code)
req, err = http.NewRequest("GET", "/admin/users/"+mockUIDb64, nil)
assert.NoError(t, err)
router.ServeHTTP(rr, req)
assert.Equal(t, http.StatusCreated, rr.Code)
var got UserInfo
err = json.Unmarshal(rr.Body.Bytes(), &got)
assert.NoError(t, err)
expected := mockUserInfo
expected.SessionsCap = partialUserInfo.SessionsCap
assert.EqualValues(t, expected, got)
})
t.Run("empty parameter", func(t *testing.T) { t.Run("empty parameter", func(t *testing.T) {
req, err := http.NewRequest("POST", "/admin/users/", bytes.NewBuffer(marshalled)) req, err := http.NewRequest("POST", "/admin/users/", bytes.NewBuffer(marshalled))
if err != nil { if err != nil {

View File

@ -127,6 +127,7 @@ func (manager *localManager) UploadStatus(uploads []StatusUpdate) ([]StatusRespo
"User no longer exists", "User no longer exists",
} }
responses = append(responses, resp) responses = append(responses, resp)
continue
} }
oldUp := int64(u64(bucket.Get([]byte("UpCredit")))) oldUp := int64(u64(bucket.Get([]byte("UpCredit"))))
@ -179,12 +180,12 @@ func (manager *localManager) ListAllUsers() (infos []UserInfo, err error) {
err = tx.ForEach(func(UID []byte, bucket *bolt.Bucket) error { err = tx.ForEach(func(UID []byte, bucket *bolt.Bucket) error {
var uinfo UserInfo var uinfo UserInfo
uinfo.UID = UID uinfo.UID = UID
uinfo.SessionsCap = int32(u32(bucket.Get([]byte("SessionsCap")))) uinfo.SessionsCap = JustInt32(int32(u32(bucket.Get([]byte("SessionsCap")))))
uinfo.UpRate = int64(u64(bucket.Get([]byte("UpRate")))) uinfo.UpRate = JustInt64(int64(u64(bucket.Get([]byte("UpRate")))))
uinfo.DownRate = int64(u64(bucket.Get([]byte("DownRate")))) uinfo.DownRate = JustInt64(int64(u64(bucket.Get([]byte("DownRate")))))
uinfo.UpCredit = int64(u64(bucket.Get([]byte("UpCredit")))) uinfo.UpCredit = JustInt64(int64(u64(bucket.Get([]byte("UpCredit")))))
uinfo.DownCredit = int64(u64(bucket.Get([]byte("DownCredit")))) uinfo.DownCredit = JustInt64(int64(u64(bucket.Get([]byte("DownCredit")))))
uinfo.ExpiryTime = int64(u64(bucket.Get([]byte("ExpiryTime")))) uinfo.ExpiryTime = JustInt64(int64(u64(bucket.Get([]byte("ExpiryTime")))))
infos = append(infos, uinfo) infos = append(infos, uinfo)
return nil return nil
}) })
@ -200,41 +201,53 @@ func (manager *localManager) GetUserInfo(UID []byte) (uinfo UserInfo, err error)
return ErrUserNotFound return ErrUserNotFound
} }
uinfo.UID = UID uinfo.UID = UID
uinfo.SessionsCap = int32(u32(bucket.Get([]byte("SessionsCap")))) uinfo.SessionsCap = JustInt32(int32(u32(bucket.Get([]byte("SessionsCap")))))
uinfo.UpRate = int64(u64(bucket.Get([]byte("UpRate")))) uinfo.UpRate = JustInt64(int64(u64(bucket.Get([]byte("UpRate")))))
uinfo.DownRate = int64(u64(bucket.Get([]byte("DownRate")))) uinfo.DownRate = JustInt64(int64(u64(bucket.Get([]byte("DownRate")))))
uinfo.UpCredit = int64(u64(bucket.Get([]byte("UpCredit")))) uinfo.UpCredit = JustInt64(int64(u64(bucket.Get([]byte("UpCredit")))))
uinfo.DownCredit = int64(u64(bucket.Get([]byte("DownCredit")))) uinfo.DownCredit = JustInt64(int64(u64(bucket.Get([]byte("DownCredit")))))
uinfo.ExpiryTime = int64(u64(bucket.Get([]byte("ExpiryTime")))) uinfo.ExpiryTime = JustInt64(int64(u64(bucket.Get([]byte("ExpiryTime")))))
return nil return nil
}) })
return return
} }
func (manager *localManager) WriteUserInfo(uinfo UserInfo) (err error) { func (manager *localManager) WriteUserInfo(u UserInfo) (err error) {
err = manager.db.Update(func(tx *bolt.Tx) error { err = manager.db.Update(func(tx *bolt.Tx) error {
bucket, err := tx.CreateBucketIfNotExists(uinfo.UID) bucket, err := tx.CreateBucketIfNotExists(u.UID)
if err != nil { if err != nil {
return err return err
} }
if err = bucket.Put([]byte("SessionsCap"), i32ToB(int32(uinfo.SessionsCap))); err != nil { if u.SessionsCap != nil {
if err = bucket.Put([]byte("SessionsCap"), i32ToB(*u.SessionsCap)); err != nil {
return err return err
} }
if err = bucket.Put([]byte("UpRate"), i64ToB(uinfo.UpRate)); err != nil { }
if u.UpRate != nil {
if err = bucket.Put([]byte("UpRate"), i64ToB(*u.UpRate)); err != nil {
return err return err
} }
if err = bucket.Put([]byte("DownRate"), i64ToB(uinfo.DownRate)); err != nil { }
if u.DownRate != nil {
if err = bucket.Put([]byte("DownRate"), i64ToB(*u.DownRate)); err != nil {
return err return err
} }
if err = bucket.Put([]byte("UpCredit"), i64ToB(uinfo.UpCredit)); err != nil { }
if u.UpCredit != nil {
if err = bucket.Put([]byte("UpCredit"), i64ToB(*u.UpCredit)); err != nil {
return err return err
} }
if err = bucket.Put([]byte("DownCredit"), i64ToB(uinfo.DownCredit)); err != nil { }
if u.DownCredit != nil {
if err = bucket.Put([]byte("DownCredit"), i64ToB(*u.DownCredit)); err != nil {
return err return err
} }
if err = bucket.Put([]byte("ExpiryTime"), i64ToB(uinfo.ExpiryTime)); err != nil { }
if u.ExpiryTime != nil {
if err = bucket.Put([]byte("ExpiryTime"), i64ToB(*u.ExpiryTime)); err != nil {
return err return err
} }
}
return nil return nil
}) })
return return

View File

@ -3,6 +3,7 @@ package usermanager
import ( import (
"encoding/binary" "encoding/binary"
"github.com/cbeuw/Cloak/internal/common" "github.com/cbeuw/Cloak/internal/common"
"github.com/stretchr/testify/assert"
"io/ioutil" "io/ioutil"
"math/rand" "math/rand"
"os" "os"
@ -17,12 +18,12 @@ var mockUID = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
var mockWorldState = common.WorldOfTime(time.Unix(1, 0)) var mockWorldState = common.WorldOfTime(time.Unix(1, 0))
var mockUserInfo = UserInfo{ var mockUserInfo = UserInfo{
UID: mockUID, UID: mockUID,
SessionsCap: 0, SessionsCap: JustInt32(10),
UpRate: 0, UpRate: JustInt64(100),
DownRate: 0, DownRate: JustInt64(1000),
UpCredit: 0, UpCredit: JustInt64(10000),
DownCredit: 0, DownCredit: JustInt64(100000),
ExpiryTime: 100, ExpiryTime: JustInt64(1000000),
} }
func makeManager(t *testing.T) (mgr *localManager, cleaner func()) { func makeManager(t *testing.T) (mgr *localManager, cleaner func()) {
@ -43,6 +44,23 @@ func TestLocalManager_WriteUserInfo(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
got, err := mgr.GetUserInfo(mockUID)
assert.NoError(t, err)
assert.EqualValues(t, mockUserInfo, got)
/* Partial update */
err = mgr.WriteUserInfo(UserInfo{
UID: mockUID,
SessionsCap: JustInt32(*mockUserInfo.SessionsCap + 1),
})
assert.NoError(t, err)
expected := mockUserInfo
expected.SessionsCap = JustInt32(*mockUserInfo.SessionsCap + 1)
got, err = mgr.GetUserInfo(mockUID)
assert.NoError(t, err)
assert.EqualValues(t, expected, got)
} }
func TestLocalManager_GetUserInfo(t *testing.T) { func TestLocalManager_GetUserInfo(t *testing.T) {
@ -63,7 +81,7 @@ func TestLocalManager_GetUserInfo(t *testing.T) {
t.Run("update a field", func(t *testing.T) { t.Run("update a field", func(t *testing.T) {
_ = mgr.WriteUserInfo(mockUserInfo) _ = mgr.WriteUserInfo(mockUserInfo)
updatedUserInfo := mockUserInfo updatedUserInfo := mockUserInfo
updatedUserInfo.SessionsCap = mockUserInfo.SessionsCap + 1 updatedUserInfo.SessionsCap = JustInt32(*mockUserInfo.SessionsCap + 1)
err := mgr.WriteUserInfo(updatedUserInfo) err := mgr.WriteUserInfo(updatedUserInfo)
if err != nil { if err != nil {
@ -103,15 +121,7 @@ func TestLocalManager_DeleteUser(t *testing.T) {
} }
} }
var validUserInfo = UserInfo{ var validUserInfo = mockUserInfo
UID: mockUID,
SessionsCap: 10,
UpRate: 100,
DownRate: 1000,
UpCredit: 10000,
DownCredit: 100000,
ExpiryTime: 1000000,
}
func TestLocalManager_AuthenticateUser(t *testing.T) { func TestLocalManager_AuthenticateUser(t *testing.T) {
var tmpDB, _ = ioutil.TempFile("", "ck_user_info") var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
@ -128,7 +138,7 @@ func TestLocalManager_AuthenticateUser(t *testing.T) {
t.Error(err) t.Error(err)
} }
if upRate != validUserInfo.UpRate || downRate != validUserInfo.DownRate { if upRate != *validUserInfo.UpRate || downRate != *validUserInfo.DownRate {
t.Error("wrong up or down rate") t.Error("wrong up or down rate")
} }
}) })
@ -142,7 +152,7 @@ func TestLocalManager_AuthenticateUser(t *testing.T) {
t.Run("expired user", func(t *testing.T) { t.Run("expired user", func(t *testing.T) {
expiredUserInfo := validUserInfo expiredUserInfo := validUserInfo
expiredUserInfo.ExpiryTime = mockWorldState.Now().Add(-10 * time.Second).Unix() expiredUserInfo.ExpiryTime = JustInt64(mockWorldState.Now().Add(-10 * time.Second).Unix())
_ = mgr.WriteUserInfo(expiredUserInfo) _ = mgr.WriteUserInfo(expiredUserInfo)
@ -154,7 +164,7 @@ func TestLocalManager_AuthenticateUser(t *testing.T) {
t.Run("no credit", func(t *testing.T) { t.Run("no credit", func(t *testing.T) {
creditlessUserInfo := validUserInfo creditlessUserInfo := validUserInfo
creditlessUserInfo.UpCredit, creditlessUserInfo.DownCredit = -1, -1 creditlessUserInfo.UpCredit, creditlessUserInfo.DownCredit = JustInt64(-1), JustInt64(-1)
_ = mgr.WriteUserInfo(creditlessUserInfo) _ = mgr.WriteUserInfo(creditlessUserInfo)
@ -186,7 +196,7 @@ func TestLocalManager_AuthoriseNewSession(t *testing.T) {
t.Run("expired user", func(t *testing.T) { t.Run("expired user", func(t *testing.T) {
expiredUserInfo := validUserInfo expiredUserInfo := validUserInfo
expiredUserInfo.ExpiryTime = mockWorldState.Now().Add(-10 * time.Second).Unix() expiredUserInfo.ExpiryTime = JustInt64(mockWorldState.Now().Add(-10 * time.Second).Unix())
_ = mgr.WriteUserInfo(expiredUserInfo) _ = mgr.WriteUserInfo(expiredUserInfo)
err := mgr.AuthoriseNewSession(expiredUserInfo.UID, AuthorisationInfo{NumExistingSessions: 0}) err := mgr.AuthoriseNewSession(expiredUserInfo.UID, AuthorisationInfo{NumExistingSessions: 0})
@ -197,7 +207,7 @@ func TestLocalManager_AuthoriseNewSession(t *testing.T) {
t.Run("too many sessions", func(t *testing.T) { t.Run("too many sessions", func(t *testing.T) {
_ = mgr.WriteUserInfo(validUserInfo) _ = mgr.WriteUserInfo(validUserInfo)
err := mgr.AuthoriseNewSession(validUserInfo.UID, AuthorisationInfo{NumExistingSessions: int(validUserInfo.SessionsCap + 1)}) err := mgr.AuthoriseNewSession(validUserInfo.UID, AuthorisationInfo{NumExistingSessions: int(*validUserInfo.SessionsCap + 1)})
if err != ErrSessionsCapReached { if err != ErrSessionsCapReached {
t.Error("session cap not reached") t.Error("session cap not reached")
} }
@ -230,10 +240,10 @@ func TestLocalManager_UploadStatus(t *testing.T) {
t.Error(err) t.Error(err)
} }
if updatedUserInfo.UpCredit != validUserInfo.UpCredit-update.UpUsage { if *updatedUserInfo.UpCredit != *validUserInfo.UpCredit-update.UpUsage {
t.Error("up usage incorrect") t.Error("up usage incorrect")
} }
if updatedUserInfo.DownCredit != validUserInfo.DownCredit-update.DownUsage { if *updatedUserInfo.DownCredit != *validUserInfo.DownCredit-update.DownUsage {
t.Error("down usage incorrect") t.Error("down usage incorrect")
} }
}) })
@ -249,7 +259,7 @@ func TestLocalManager_UploadStatus(t *testing.T) {
UID: validUserInfo.UID, UID: validUserInfo.UID,
Active: true, Active: true,
NumSession: 1, NumSession: 1,
UpUsage: validUserInfo.UpCredit + 100, UpUsage: *validUserInfo.UpCredit + 100,
DownUsage: 0, DownUsage: 0,
Timestamp: mockWorldState.Now().Unix(), Timestamp: mockWorldState.Now().Unix(),
}, },
@ -261,19 +271,19 @@ func TestLocalManager_UploadStatus(t *testing.T) {
Active: true, Active: true,
NumSession: 1, NumSession: 1,
UpUsage: 0, UpUsage: 0,
DownUsage: validUserInfo.DownCredit + 100, DownUsage: *validUserInfo.DownCredit + 100,
Timestamp: mockWorldState.Now().Unix(), Timestamp: mockWorldState.Now().Unix(),
}, },
}, },
{"expired", {"expired",
UserInfo{ UserInfo{
UID: mockUID, UID: mockUID,
SessionsCap: 10, SessionsCap: JustInt32(10),
UpRate: 0, UpRate: JustInt64(0),
DownRate: 0, DownRate: JustInt64(0),
UpCredit: 0, UpCredit: JustInt64(0),
DownCredit: 0, DownCredit: JustInt64(0),
ExpiryTime: -1, ExpiryTime: JustInt64(-1),
}, },
StatusUpdate{ StatusUpdate{
UID: mockUserInfo.UID, UID: mockUserInfo.UID,
@ -318,12 +328,12 @@ func TestLocalManager_ListAllUsers(t *testing.T) {
rand.Read(randUID) rand.Read(randUID)
newUser := UserInfo{ newUser := UserInfo{
UID: randUID, UID: randUID,
SessionsCap: rand.Int31(), SessionsCap: JustInt32(rand.Int31()),
UpRate: rand.Int63(), UpRate: JustInt64(rand.Int63()),
DownRate: rand.Int63(), DownRate: JustInt64(rand.Int63()),
UpCredit: rand.Int63(), UpCredit: JustInt64(rand.Int63()),
DownCredit: rand.Int63(), DownCredit: JustInt64(rand.Int63()),
ExpiryTime: rand.Int63(), ExpiryTime: JustInt64(rand.Int63()),
} }
users = append(users, newUser) users = append(users, newUser)
wg.Add(1) wg.Add(1)

View File

@ -14,16 +14,23 @@ type StatusUpdate struct {
Timestamp int64 Timestamp int64
} }
type MaybeInt32 *int32
type MaybeInt64 *int64
type UserInfo struct { type UserInfo struct {
UID []byte UID []byte
SessionsCap int32 SessionsCap MaybeInt32
UpRate int64 UpRate MaybeInt64
DownRate int64 DownRate MaybeInt64
UpCredit int64 UpCredit MaybeInt64
DownCredit int64 DownCredit MaybeInt64
ExpiryTime int64 ExpiryTime MaybeInt64
} }
func JustInt32(v int32) MaybeInt32 { return &v }
func JustInt64(v int64) MaybeInt64 { return &v }
type StatusResponse struct { type StatusResponse struct {
UID []byte UID []byte
Action int Action int

View File

@ -66,12 +66,12 @@ var mockUID = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
var mockWorldState = common.WorldOfTime(time.Unix(1, 0)) var mockWorldState = common.WorldOfTime(time.Unix(1, 0))
var validUserInfo = usermanager.UserInfo{ var validUserInfo = usermanager.UserInfo{
UID: mockUID, UID: mockUID,
SessionsCap: 10, SessionsCap: usermanager.JustInt32(10),
UpRate: 100, UpRate: usermanager.JustInt64(100),
DownRate: 1000, DownRate: usermanager.JustInt64(1000),
UpCredit: 10000, UpCredit: usermanager.JustInt64(10000),
DownCredit: 100000, DownCredit: usermanager.JustInt64(100000),
ExpiryTime: 1000000, ExpiryTime: usermanager.JustInt64(1000000),
} }
func TestUserPanel_GetUser(t *testing.T) { func TestUserPanel_GetUser(t *testing.T) {
@ -138,10 +138,10 @@ func TestUserPanel_UpdateUsageQueue(t *testing.T) {
} }
updatedUinfo, _ := mgr.GetUserInfo(validUserInfo.UID) updatedUinfo, _ := mgr.GetUserInfo(validUserInfo.UID)
if updatedUinfo.DownCredit != validUserInfo.DownCredit-1 { if *updatedUinfo.DownCredit != *validUserInfo.DownCredit-1 {
t.Error("down credit incorrect update") t.Error("down credit incorrect update")
} }
if updatedUinfo.UpCredit != validUserInfo.UpCredit-2 { if *updatedUinfo.UpCredit != *validUserInfo.UpCredit-2 {
t.Error("up credit incorrect update") t.Error("up credit incorrect update")
} }
@ -155,10 +155,10 @@ func TestUserPanel_UpdateUsageQueue(t *testing.T) {
} }
updatedUinfo, _ = mgr.GetUserInfo(validUserInfo.UID) updatedUinfo, _ = mgr.GetUserInfo(validUserInfo.UID)
if updatedUinfo.DownCredit != validUserInfo.DownCredit-(1+3) { if *updatedUinfo.DownCredit != *validUserInfo.DownCredit-(1+3) {
t.Error("down credit incorrect update") t.Error("down credit incorrect update")
} }
if updatedUinfo.UpCredit != validUserInfo.UpCredit-(2+4) { if *updatedUinfo.UpCredit != *validUserInfo.UpCredit-(2+4) {
t.Error("up credit incorrect update") t.Error("up credit incorrect update")
} }
}) })
@ -170,7 +170,7 @@ func TestUserPanel_UpdateUsageQueue(t *testing.T) {
t.Error(err) t.Error(err)
} }
user.valve.AddTx(validUserInfo.DownCredit + 100) user.valve.AddTx(*validUserInfo.DownCredit + 100)
panel.updateUsageQueue() panel.updateUsageQueue()
err = panel.commitUpdate() err = panel.commitUpdate()
if err != nil { if err != nil {
@ -182,7 +182,7 @@ func TestUserPanel_UpdateUsageQueue(t *testing.T) {
} }
updatedUinfo, _ := mgr.GetUserInfo(validUserInfo.UID) updatedUinfo, _ := mgr.GetUserInfo(validUserInfo.UID)
if updatedUinfo.DownCredit != -100 { if *updatedUinfo.DownCredit != -100 {
t.Error("down credit not updated correctly after the user has been terminated") t.Error("down credit not updated correctly after the user has been terminated")
} }
}) })

View File

@ -321,7 +321,7 @@ func TestTCPSingleplex(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
const echoMsgLen = 16384 const echoMsgLen = 1 << 16
go serveTCPEcho(proxyFromCkServerL) go serveTCPEcho(proxyFromCkServerL)
proxyConn1, err := proxyToCkClientD.Dial("", "") proxyConn1, err := proxyToCkClientD.Dial("", "")

View File

@ -12,7 +12,7 @@ if [ -z "$v" ]; then
fi fi
output="{{.Dir}}-{{.OS}}-{{.Arch}}-$v" output="{{.Dir}}-{{.OS}}-{{.Arch}}-$v"
osarch="!darwin/arm !darwin/arm64 !darwin/386" osarch="!darwin/arm !darwin/386"
echo "Compiling:" echo "Compiling:"