mirror of https://github.com/cbeuw/Cloak
Merge branch 'master' into notsure2
This commit is contained in:
commit
ba94c594f0
13
README.md
13
README.md
|
|
@ -135,6 +135,19 @@ encryption and authentication (via AEAD or similar techniques).**
|
||||||
`ServerName` is the domain you want to make your ISP or firewall _think_ you are visiting. Ideally it should
|
`ServerName` is the domain you want to make your ISP or firewall _think_ you are visiting. Ideally it should
|
||||||
match `RedirAddr` in the server's configuration, a major site the censor allows, but it doesn't have to.
|
match `RedirAddr` in the server's configuration, a major site the censor allows, but it doesn't have to.
|
||||||
|
|
||||||
|
`AlternativeNames` is an array used alongside `ServerName` to shuffle between different ServerNames for every new
|
||||||
|
connection. **This may conflict with `CDN` Transport mode** if the CDN provider prohibits domain fronting and rejects
|
||||||
|
the alternative domains.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"ServerName": "bing.com",
|
||||||
|
"AlternativeNames": ["cloudflare.com", "github.com"]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
`CDNOriginHost` is the domain name of the _origin_ server (i.e. the server running Cloak) under `CDN` mode. This only
|
`CDNOriginHost` is the domain name of the _origin_ server (i.e. the server running Cloak) under `CDN` mode. This only
|
||||||
has effect when `Transport` is set to `CDN`. If unset, it will default to the remote hostname supplied via the
|
has effect when `Transport` is set to `CDN`. If unset, it will default to the remote hostname supplied via the
|
||||||
commandline argument (in standalone mode), or by Shadowsocks (in plugin mode). After a TLS session is established with
|
commandline argument (in standalone mode), or by Shadowsocks (in plugin mode). After a TLS session is established with
|
||||||
|
|
|
||||||
|
|
@ -173,6 +173,11 @@ func main() {
|
||||||
log.Infof("Listening on %v %v for %v client", network, localConfig.LocalAddr, authInfo.ProxyMethod)
|
log.Infof("Listening on %v %v for %v client", network, localConfig.LocalAddr, authInfo.ProxyMethod)
|
||||||
seshMaker = func() *mux.Session {
|
seshMaker = func() *mux.Session {
|
||||||
authInfo := authInfo // copy the struct because we are overwriting SessionId
|
authInfo := authInfo // copy the struct because we are overwriting SessionId
|
||||||
|
|
||||||
|
randByte := make([]byte, 1)
|
||||||
|
common.RandRead(authInfo.WorldState.Rand, randByte)
|
||||||
|
authInfo.MockDomain = localConfig.MockDomainList[int(randByte[0])%len(localConfig.MockDomainList)]
|
||||||
|
|
||||||
// sessionID is usergenerated. There shouldn't be a security concern because the scope of
|
// sessionID is usergenerated. There shouldn't be a security concern because the scope of
|
||||||
// sessionID is limited to its UID.
|
// sessionID is limited to its UID.
|
||||||
quad := make([]byte, 4)
|
quad := make([]byte, 4)
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
// +build android
|
// +build android
|
||||||
|
|
||||||
package main
|
package main
|
||||||
|
|
||||||
// Stolen from https://github.com/shadowsocks/overture/blob/shadowsocks/core/utils/utils_android.go
|
// Stolen from https://github.com/shadowsocks/overture/blob/shadowsocks/core/utils/utils_android.go
|
||||||
|
|
|
||||||
|
|
@ -9,49 +9,27 @@ import (
|
||||||
func TestParseBindAddr(t *testing.T) {
|
func TestParseBindAddr(t *testing.T) {
|
||||||
t.Run("port only", func(t *testing.T) {
|
t.Run("port only", func(t *testing.T) {
|
||||||
addrs, err := resolveBindAddr([]string{":443"})
|
addrs, err := resolveBindAddr([]string{":443"})
|
||||||
if err != nil {
|
assert.NoError(t, err)
|
||||||
t.Error(err)
|
assert.Equal(t, ":443", addrs[0].String())
|
||||||
return
|
|
||||||
}
|
|
||||||
if addrs[0].String() != ":443" {
|
|
||||||
t.Errorf("expected %v got %v", ":443", addrs[0].String())
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("specific address", func(t *testing.T) {
|
t.Run("specific address", func(t *testing.T) {
|
||||||
addrs, err := resolveBindAddr([]string{"192.168.1.123:443"})
|
addrs, err := resolveBindAddr([]string{"192.168.1.123:443"})
|
||||||
if err != nil {
|
assert.NoError(t, err)
|
||||||
t.Error(err)
|
assert.Equal(t, "192.168.1.123:443", addrs[0].String())
|
||||||
return
|
|
||||||
}
|
|
||||||
if addrs[0].String() != "192.168.1.123:443" {
|
|
||||||
t.Errorf("expected %v got %v", "192.168.1.123:443", addrs[0].String())
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("ipv6", func(t *testing.T) {
|
t.Run("ipv6", func(t *testing.T) {
|
||||||
addrs, err := resolveBindAddr([]string{"[::]:443"})
|
addrs, err := resolveBindAddr([]string{"[::]:443"})
|
||||||
if err != nil {
|
assert.NoError(t, err)
|
||||||
t.Error(err)
|
assert.Equal(t, "[::]:443", addrs[0].String())
|
||||||
return
|
|
||||||
}
|
|
||||||
if addrs[0].String() != "[::]:443" {
|
|
||||||
t.Errorf("expected %v got %v", "[::]:443", addrs[0].String())
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("mixed", func(t *testing.T) {
|
t.Run("mixed", func(t *testing.T) {
|
||||||
addrs, err := resolveBindAddr([]string{":80", "[::]:443"})
|
addrs, err := resolveBindAddr([]string{":80", "[::]:443"})
|
||||||
if err != nil {
|
assert.NoError(t, err)
|
||||||
t.Error(err)
|
assert.Equal(t, ":80", addrs[0].String())
|
||||||
return
|
assert.Equal(t, "[::]:443", addrs[1].String())
|
||||||
}
|
|
||||||
if addrs[0].String() != ":80" {
|
|
||||||
t.Errorf("expected %v got %v", ":80", addrs[0].String())
|
|
||||||
}
|
|
||||||
if addrs[1].String() != "[::]:443" {
|
|
||||||
t.Errorf("expected %v got %v", "[::]:443", addrs[1].String())
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -33,11 +33,6 @@ func TestMakeServerName(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, p := range pairs {
|
for _, p := range pairs {
|
||||||
if !bytes.Equal(makeServerName(p.serverName), p.target) {
|
assert.Equal(t, p.target, makeServerName(p.serverName))
|
||||||
t.Error(
|
|
||||||
"for", p.serverName,
|
|
||||||
"expecting", p.target,
|
|
||||||
"got", makeServerName(p.serverName))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"github.com/cbeuw/Cloak/internal/common"
|
"github.com/cbeuw/Cloak/internal/common"
|
||||||
"github.com/cbeuw/Cloak/internal/multiplex"
|
"github.com/cbeuw/Cloak/internal/multiplex"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
@ -64,12 +65,8 @@ func TestMakeAuthenticationPayload(t *testing.T) {
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
func() {
|
func() {
|
||||||
payload, sharedSecret := makeAuthenticationPayload(tc.authInfo)
|
payload, sharedSecret := makeAuthenticationPayload(tc.authInfo)
|
||||||
if payload != tc.expPayload {
|
assert.Equal(t, tc.expPayload, payload, "payload doesn't match")
|
||||||
t.Errorf("payload doesn't match:\nexp %v\ngot %v", tc.expPayload, payload)
|
assert.Equal(t, tc.expSecret, sharedSecret, "shared secret doesn't match")
|
||||||
}
|
|
||||||
if sharedSecret != tc.expSecret {
|
|
||||||
t.Errorf("secret doesn't match:\nexp %x\ngot %x", tc.expPayload, payload)
|
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ type RawConfig struct {
|
||||||
LocalPort string // jsonOptional
|
LocalPort string // jsonOptional
|
||||||
RemoteHost string // jsonOptional
|
RemoteHost string // jsonOptional
|
||||||
RemotePort string // jsonOptional
|
RemotePort string // jsonOptional
|
||||||
|
AlternativeNames []string // jsonOptional
|
||||||
// defaults set in ProcessRawConfig
|
// defaults set in ProcessRawConfig
|
||||||
UDP bool // nullable
|
UDP bool // nullable
|
||||||
BrowserSig string // nullable
|
BrowserSig string // nullable
|
||||||
|
|
@ -51,6 +51,7 @@ type RemoteConnConfig struct {
|
||||||
type LocalConnConfig struct {
|
type LocalConnConfig struct {
|
||||||
LocalAddr string
|
LocalAddr string
|
||||||
Timeout time.Duration
|
Timeout time.Duration
|
||||||
|
MockDomainList []string
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthInfo struct {
|
type AuthInfo struct {
|
||||||
|
|
@ -94,6 +95,20 @@ func ssvToJson(ssv string) (ret []byte) {
|
||||||
}
|
}
|
||||||
key := sp[0]
|
key := sp[0]
|
||||||
value := sp[1]
|
value := sp[1]
|
||||||
|
if strings.HasPrefix(key, "AlternativeNames") {
|
||||||
|
switch strings.Contains(value, ",") {
|
||||||
|
case true:
|
||||||
|
domains := strings.Split(value, ",")
|
||||||
|
for index, domain := range domains {
|
||||||
|
domains[index] = `"` + domain + `"`
|
||||||
|
}
|
||||||
|
value = strings.Join(domains, ",")
|
||||||
|
ret = append(ret, []byte(`"`+key+`":[`+value+`],`)...)
|
||||||
|
case false:
|
||||||
|
ret = append(ret, []byte(`"`+key+`":["`+value+`"],`)...)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
// JSON doesn't like quotation marks around int and bool
|
// JSON doesn't like quotation marks around int and bool
|
||||||
// This is extremely ugly but it's still better than writing a tokeniser
|
// This is extremely ugly but it's still better than writing a tokeniser
|
||||||
if elem(key, unquoted) {
|
if elem(key, unquoted) {
|
||||||
|
|
@ -139,6 +154,8 @@ func (raw *RawConfig) ProcessRawConfig(worldState common.WorldState) (local Loca
|
||||||
return nullErr("ServerName")
|
return nullErr("ServerName")
|
||||||
}
|
}
|
||||||
auth.MockDomain = raw.ServerName
|
auth.MockDomain = raw.ServerName
|
||||||
|
local.MockDomainList = raw.AlternativeNames
|
||||||
|
local.MockDomainList = append(local.MockDomainList, auth.MockDomain)
|
||||||
if raw.ProxyMethod == "" {
|
if raw.ProxyMethod == "" {
|
||||||
return nullErr("ServerName")
|
return nullErr("ServerName")
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
@ -94,6 +95,9 @@ func (tls *TLSConn) Read(buffer []byte) (n int, err error) {
|
||||||
|
|
||||||
func (tls *TLSConn) Write(in []byte) (n int, err error) {
|
func (tls *TLSConn) Write(in []byte) (n int, err error) {
|
||||||
msgLen := len(in)
|
msgLen := len(in)
|
||||||
|
if msgLen > 1<<14+256 { // https://tools.ietf.org/html/rfc8446#section-5.2
|
||||||
|
return 0, errors.New("message is too long")
|
||||||
|
}
|
||||||
writeBuf := tls.writeBufPool.Get().(*[]byte)
|
writeBuf := tls.writeBufPool.Get().(*[]byte)
|
||||||
*writeBuf = append(*writeBuf, byte(msgLen>>8), byte(msgLen&0xFF))
|
*writeBuf = append(*writeBuf, byte(msgLen>>8), byte(msgLen&0xFF))
|
||||||
*writeBuf = append(*writeBuf, in...)
|
*writeBuf = append(*writeBuf, in...)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
package multiplex
|
package multiplex
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"github.com/stretchr/testify/assert"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
@ -11,13 +11,7 @@ func TestDatagramBuffer_RW(t *testing.T) {
|
||||||
t.Run("simple write", func(t *testing.T) {
|
t.Run("simple write", func(t *testing.T) {
|
||||||
pipe := NewDatagramBufferedPipe()
|
pipe := NewDatagramBufferedPipe()
|
||||||
_, err := pipe.Write(&Frame{Payload: b})
|
_, err := pipe.Write(&Frame{Payload: b})
|
||||||
if err != nil {
|
assert.NoError(t, err)
|
||||||
t.Error(
|
|
||||||
"expecting", "nil error",
|
|
||||||
"got", err,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("simple read", func(t *testing.T) {
|
t.Run("simple read", func(t *testing.T) {
|
||||||
|
|
@ -25,50 +19,18 @@ func TestDatagramBuffer_RW(t *testing.T) {
|
||||||
_, _ = pipe.Write(&Frame{Payload: b})
|
_, _ = pipe.Write(&Frame{Payload: b})
|
||||||
b2 := make([]byte, len(b))
|
b2 := make([]byte, len(b))
|
||||||
n, err := pipe.Read(b2)
|
n, err := pipe.Read(b2)
|
||||||
if n != len(b) {
|
assert.NoError(t, err)
|
||||||
t.Error(
|
assert.Equal(t, len(b), n)
|
||||||
"For", "number of bytes read",
|
assert.Equal(t, b, b2)
|
||||||
"expecting", len(b),
|
assert.Equal(t, 0, pipe.buf.Len(), "buf len is not 0 after finished reading")
|
||||||
"got", n,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
t.Error(
|
|
||||||
"expecting", "nil error",
|
|
||||||
"got", err,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !bytes.Equal(b, b2) {
|
|
||||||
t.Error(
|
|
||||||
"expecting", b,
|
|
||||||
"got", b2,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
if pipe.buf.Len() != 0 {
|
|
||||||
t.Error("buf len is not 0 after finished reading")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("writing closing frame", func(t *testing.T) {
|
t.Run("writing closing frame", func(t *testing.T) {
|
||||||
pipe := NewDatagramBufferedPipe()
|
pipe := NewDatagramBufferedPipe()
|
||||||
toBeClosed, err := pipe.Write(&Frame{Closing: closingStream})
|
toBeClosed, err := pipe.Write(&Frame{Closing: closingStream})
|
||||||
if !toBeClosed {
|
assert.NoError(t, err)
|
||||||
t.Error("should be to be closed")
|
assert.True(t, toBeClosed, "should be to be closed")
|
||||||
}
|
assert.True(t, pipe.closed, "pipe should be closed")
|
||||||
if err != nil {
|
|
||||||
t.Error(
|
|
||||||
"expecting", "nil error",
|
|
||||||
"got", err,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !pipe.closed {
|
|
||||||
t.Error("expecting closed pipe, not closed")
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -81,30 +43,9 @@ func TestDatagramBuffer_BlockingRead(t *testing.T) {
|
||||||
}()
|
}()
|
||||||
b2 := make([]byte, len(b))
|
b2 := make([]byte, len(b))
|
||||||
n, err := pipe.Read(b2)
|
n, err := pipe.Read(b2)
|
||||||
if n != len(b) {
|
assert.NoError(t, err)
|
||||||
t.Error(
|
assert.Equal(t, len(b), n, "number of bytes read after block is wrong")
|
||||||
"For", "number of bytes read after block",
|
assert.Equal(t, b, b2)
|
||||||
"expecting", len(b),
|
|
||||||
"got", n,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
t.Error(
|
|
||||||
"For", "blocked read",
|
|
||||||
"expecting", "nil error",
|
|
||||||
"got", err,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !bytes.Equal(b, b2) {
|
|
||||||
t.Error(
|
|
||||||
"For", "blocked read",
|
|
||||||
"expecting", b,
|
|
||||||
"got", b2,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDatagramBuffer_CloseThenRead(t *testing.T) {
|
func TestDatagramBuffer_CloseThenRead(t *testing.T) {
|
||||||
|
|
@ -114,27 +55,7 @@ func TestDatagramBuffer_CloseThenRead(t *testing.T) {
|
||||||
b2 := make([]byte, len(b))
|
b2 := make([]byte, len(b))
|
||||||
pipe.Close()
|
pipe.Close()
|
||||||
n, err := pipe.Read(b2)
|
n, err := pipe.Read(b2)
|
||||||
if n != len(b) {
|
assert.NoError(t, err)
|
||||||
t.Error(
|
assert.Equal(t, len(b), n, "number of bytes read after block is wrong")
|
||||||
"For", "number of bytes read",
|
assert.Equal(t, b, b2)
|
||||||
"expecting", len(b),
|
|
||||||
"got", n,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
t.Error(
|
|
||||||
"For", "simple read",
|
|
||||||
"expecting", "nil error",
|
|
||||||
"got", err,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !bytes.Equal(b, b2) {
|
|
||||||
t.Error(
|
|
||||||
"For", "simple read",
|
|
||||||
"expecting", b,
|
|
||||||
"got", b2,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -108,9 +108,7 @@ func TestMultiplex(t *testing.T) {
|
||||||
streams := make([]net.Conn, numStreams)
|
streams := make([]net.Conn, numStreams)
|
||||||
for i := 0; i < numStreams; i++ {
|
for i := 0; i < numStreams; i++ {
|
||||||
stream, err := clientSession.OpenStream()
|
stream, err := clientSession.OpenStream()
|
||||||
if err != nil {
|
assert.NoError(t, err)
|
||||||
t.Fatalf("failed to open stream: %v", err)
|
|
||||||
}
|
|
||||||
streams[i] = stream
|
streams[i] = stream
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -123,18 +121,11 @@ func TestMultiplex(t *testing.T) {
|
||||||
// close one stream
|
// close one stream
|
||||||
closing, streams := streams[0], streams[1:]
|
closing, streams := streams[0], streams[1:]
|
||||||
err := closing.Close()
|
err := closing.Close()
|
||||||
if err != nil {
|
assert.NoError(t, err, "couldn't close a stream")
|
||||||
t.Errorf("couldn't close a stream")
|
|
||||||
}
|
|
||||||
_, err = closing.Write([]byte{0})
|
_, err = closing.Write([]byte{0})
|
||||||
if err != ErrBrokenStream {
|
assert.Equal(t, ErrBrokenStream, err)
|
||||||
t.Errorf("expecting error %v, got %v", ErrBrokenStream, err)
|
|
||||||
}
|
|
||||||
_, err = closing.Read(make([]byte, 1))
|
_, err = closing.Read(make([]byte, 1))
|
||||||
if err != ErrBrokenStream {
|
assert.Equal(t, ErrBrokenStream, err)
|
||||||
t.Errorf("expecting error %v, got %v", ErrBrokenStream, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMux_StreamClosing(t *testing.T) {
|
func TestMux_StreamClosing(t *testing.T) {
|
||||||
|
|
@ -146,20 +137,13 @@ func TestMux_StreamClosing(t *testing.T) {
|
||||||
recvBuf := make([]byte, 128)
|
recvBuf := make([]byte, 128)
|
||||||
toBeClosed, _ := clientSession.OpenStream()
|
toBeClosed, _ := clientSession.OpenStream()
|
||||||
_, err := toBeClosed.Write(testData) // should be echoed back
|
_, err := toBeClosed.Write(testData) // should be echoed back
|
||||||
if err != nil {
|
assert.NoError(t, err, "couldn't write to a stream")
|
||||||
t.Errorf("can't write to stream: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = io.ReadFull(toBeClosed, recvBuf[:1])
|
_, err = io.ReadFull(toBeClosed, recvBuf[:1])
|
||||||
if err != nil {
|
assert.NoError(t, err, "can't read anything before stream closed")
|
||||||
t.Errorf("can't read anything before stream closed: %v", err)
|
|
||||||
}
|
|
||||||
_ = toBeClosed.Close()
|
_ = toBeClosed.Close()
|
||||||
_, err = io.ReadFull(toBeClosed, recvBuf[1:])
|
_, err = io.ReadFull(toBeClosed, recvBuf[1:])
|
||||||
if err != nil {
|
assert.NoError(t, err, "can't read residual data on stream")
|
||||||
t.Errorf("can't read residual data on stream: %v", err)
|
assert.Equal(t, testData, recvBuf, "incorrect data read back")
|
||||||
}
|
|
||||||
if !bytes.Equal(testData, recvBuf) {
|
|
||||||
t.Errorf("incorrect data read back")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -138,7 +138,7 @@ func BenchmarkObfs(b *testing.B) {
|
||||||
testPayload,
|
testPayload,
|
||||||
}
|
}
|
||||||
|
|
||||||
obfsBuf := make([]byte, defaultSendRecvBufSize)
|
obfsBuf := make([]byte, len(testPayload)*2)
|
||||||
|
|
||||||
var key [32]byte
|
var key [32]byte
|
||||||
rand.Read(key[:])
|
rand.Read(key[:])
|
||||||
|
|
@ -211,7 +211,7 @@ func BenchmarkDeobfs(b *testing.B) {
|
||||||
testPayload,
|
testPayload,
|
||||||
}
|
}
|
||||||
|
|
||||||
obfsBuf := make([]byte, defaultSendRecvBufSize)
|
obfsBuf := make([]byte, len(testPayload)*2)
|
||||||
|
|
||||||
var key [32]byte
|
var key [32]byte
|
||||||
rand.Read(key[:])
|
rand.Read(key[:])
|
||||||
|
|
|
||||||
|
|
@ -25,4 +25,4 @@ type recvBuffer interface {
|
||||||
// size we want the amount of unread data in buffer to grow before recvBuffer.Write blocks.
|
// size we want the amount of unread data in buffer to grow before recvBuffer.Write blocks.
|
||||||
// If the buffer grows larger than what the system's memory can offer at the time of recvBuffer.Write,
|
// If the buffer grows larger than what the system's memory can offer at the time of recvBuffer.Write,
|
||||||
// a panic will happen.
|
// a panic will happen.
|
||||||
const recvBufferSizeLimit = defaultSendRecvBufSize << 12
|
const recvBufferSizeLimit = 1<<31 - 1
|
||||||
|
|
|
||||||
|
|
@ -14,9 +14,8 @@ import (
|
||||||
|
|
||||||
const (
|
const (
|
||||||
acceptBacklog = 1024
|
acceptBacklog = 1024
|
||||||
// TODO: will this be a signature?
|
|
||||||
defaultSendRecvBufSize = 20480
|
|
||||||
defaultInactivityTimeout = 30 * time.Second
|
defaultInactivityTimeout = 30 * time.Second
|
||||||
|
defaultMaxOnWireSize = 1<<14 + 256 // https://tools.ietf.org/html/rfc8446#section-5.2
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrBrokenSession = errors.New("broken session")
|
var ErrBrokenSession = errors.New("broken session")
|
||||||
|
|
@ -24,8 +23,6 @@ var errRepeatSessionClosing = errors.New("trying to close a closed session")
|
||||||
var errRepeatStreamClosing = errors.New("trying to close a closed stream")
|
var errRepeatStreamClosing = errors.New("trying to close a closed stream")
|
||||||
var errNoMultiplex = errors.New("a singleplexing session can have only one stream")
|
var errNoMultiplex = errors.New("a singleplexing session can have only one stream")
|
||||||
|
|
||||||
type switchboardStrategy int
|
|
||||||
|
|
||||||
type SessionConfig struct {
|
type SessionConfig struct {
|
||||||
Obfuscator
|
Obfuscator
|
||||||
|
|
||||||
|
|
@ -40,12 +37,6 @@ type SessionConfig struct {
|
||||||
// maximum size of an obfuscated frame, including headers and overhead
|
// maximum size of an obfuscated frame, including headers and overhead
|
||||||
MsgOnWireSizeLimit int
|
MsgOnWireSizeLimit int
|
||||||
|
|
||||||
// StreamSendBufferSize sets the buffer size used to send data from a Stream (Stream.obfsBuf)
|
|
||||||
StreamSendBufferSize int
|
|
||||||
// ConnReceiveBufferSize sets the buffer size used to receive data from an underlying Conn (allocated in
|
|
||||||
// switchboard.deplex)
|
|
||||||
ConnReceiveBufferSize int
|
|
||||||
|
|
||||||
// InactivityTimeout sets the duration a Session waits while it has no active streams before it closes itself
|
// InactivityTimeout sets the duration a Session waits while it has no active streams before it closes itself
|
||||||
InactivityTimeout time.Duration
|
InactivityTimeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
@ -82,11 +73,17 @@ type Session struct {
|
||||||
|
|
||||||
closed uint32
|
closed uint32
|
||||||
|
|
||||||
terminalMsg atomic.Value
|
terminalMsgSetter sync.Once
|
||||||
|
terminalMsg string
|
||||||
|
|
||||||
// the max size passed to Write calls before it splits it into multiple frames
|
// the max size passed to Write calls before it splits it into multiple frames
|
||||||
// i.e. the max size a piece of data can fit into a Frame.Payload
|
// i.e. the max size a piece of data can fit into a Frame.Payload
|
||||||
maxStreamUnitWrite int
|
maxStreamUnitWrite int
|
||||||
|
// streamSendBufferSize sets the buffer size used to send data from a Stream (Stream.obfsBuf)
|
||||||
|
streamSendBufferSize int
|
||||||
|
// connReceiveBufferSize sets the buffer size used to receive data from an underlying Conn (allocated in
|
||||||
|
// switchboard.deplex)
|
||||||
|
connReceiveBufferSize int
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeSession(id uint32, config SessionConfig) *Session {
|
func MakeSession(id uint32, config SessionConfig) *Session {
|
||||||
|
|
@ -103,23 +100,19 @@ func MakeSession(id uint32, config SessionConfig) *Session {
|
||||||
if config.Valve == nil {
|
if config.Valve == nil {
|
||||||
sesh.Valve = UNLIMITED_VALVE
|
sesh.Valve = UNLIMITED_VALVE
|
||||||
}
|
}
|
||||||
if config.StreamSendBufferSize <= 0 {
|
|
||||||
sesh.StreamSendBufferSize = defaultSendRecvBufSize
|
|
||||||
}
|
|
||||||
if config.ConnReceiveBufferSize <= 0 {
|
|
||||||
sesh.ConnReceiveBufferSize = defaultSendRecvBufSize
|
|
||||||
}
|
|
||||||
if config.MsgOnWireSizeLimit <= 0 {
|
if config.MsgOnWireSizeLimit <= 0 {
|
||||||
sesh.MsgOnWireSizeLimit = defaultSendRecvBufSize - 1024
|
sesh.MsgOnWireSizeLimit = defaultMaxOnWireSize
|
||||||
}
|
}
|
||||||
if config.InactivityTimeout == 0 {
|
if config.InactivityTimeout == 0 {
|
||||||
sesh.InactivityTimeout = defaultInactivityTimeout
|
sesh.InactivityTimeout = defaultInactivityTimeout
|
||||||
}
|
}
|
||||||
// todo: validation. this must be smaller than StreamSendBufferSize
|
|
||||||
sesh.maxStreamUnitWrite = sesh.MsgOnWireSizeLimit - frameHeaderLength - sesh.Obfuscator.maxOverhead
|
sesh.maxStreamUnitWrite = sesh.MsgOnWireSizeLimit - frameHeaderLength - sesh.maxOverhead
|
||||||
|
sesh.streamSendBufferSize = sesh.MsgOnWireSizeLimit
|
||||||
|
sesh.connReceiveBufferSize = 20480 // for backwards compatibility
|
||||||
|
|
||||||
sesh.streamObfsBufPool = sync.Pool{New: func() interface{} {
|
sesh.streamObfsBufPool = sync.Pool{New: func() interface{} {
|
||||||
b := make([]byte, sesh.StreamSendBufferSize)
|
b := make([]byte, sesh.streamSendBufferSize)
|
||||||
return &b
|
return &b
|
||||||
}}
|
}}
|
||||||
|
|
||||||
|
|
@ -187,7 +180,7 @@ func (sesh *Session) closeStream(s *Stream, active bool) error {
|
||||||
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
|
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
|
||||||
return fmt.Errorf("closing stream %v: %w", s.id, errRepeatStreamClosing)
|
return fmt.Errorf("closing stream %v: %w", s.id, errRepeatStreamClosing)
|
||||||
}
|
}
|
||||||
_ = s.getRecvBuf().Close() // recvBuf.Close should not return error
|
_ = s.recvBuf.Close() // recvBuf.Close should not return error
|
||||||
|
|
||||||
if active {
|
if active {
|
||||||
tmpBuf := sesh.streamObfsBufPool.Get().(*[]byte)
|
tmpBuf := sesh.streamObfsBufPool.Get().(*[]byte)
|
||||||
|
|
@ -271,16 +264,13 @@ func (sesh *Session) recvDataFromRemote(data []byte) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sesh *Session) SetTerminalMsg(msg string) {
|
func (sesh *Session) SetTerminalMsg(msg string) {
|
||||||
sesh.terminalMsg.Store(msg)
|
sesh.terminalMsgSetter.Do(func() {
|
||||||
|
sesh.terminalMsg = msg
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sesh *Session) TerminalMsg() string {
|
func (sesh *Session) TerminalMsg() string {
|
||||||
msg := sesh.terminalMsg.Load()
|
return sesh.terminalMsg
|
||||||
if msg != nil {
|
|
||||||
return msg.(string)
|
|
||||||
} else {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sesh *Session) closeSession() error {
|
func (sesh *Session) closeSession() error {
|
||||||
|
|
@ -292,14 +282,12 @@ func (sesh *Session) closeSession() error {
|
||||||
sesh.streamsM.Lock()
|
sesh.streamsM.Lock()
|
||||||
close(sesh.acceptCh)
|
close(sesh.acceptCh)
|
||||||
for id, stream := range sesh.streams {
|
for id, stream := range sesh.streams {
|
||||||
if stream == nil {
|
if stream != nil && atomic.CompareAndSwapUint32(&stream.closed, 0, 1) {
|
||||||
continue
|
_ = stream.recvBuf.Close() // will not block
|
||||||
}
|
|
||||||
atomic.StoreUint32(&stream.closed, 1)
|
|
||||||
_ = stream.getRecvBuf().Close() // will not block
|
|
||||||
delete(sesh.streams, id)
|
delete(sesh.streams, id)
|
||||||
sesh.streamCountDecr()
|
sesh.streamCountDecr()
|
||||||
}
|
}
|
||||||
|
}
|
||||||
sesh.streamsM.Unlock()
|
sesh.streamsM.Unlock()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
@ -339,7 +327,7 @@ func (sesh *Session) Close() error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = sesh.sb.send((*buf)[:i], new(uint32))
|
_, err = sesh.sb.send((*buf)[:i], new(net.Conn))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -534,7 +534,7 @@ func TestSession_timeoutAfter(t *testing.T) {
|
||||||
func BenchmarkRecvDataFromRemote(b *testing.B) {
|
func BenchmarkRecvDataFromRemote(b *testing.B) {
|
||||||
testPayload := make([]byte, testPayloadLen)
|
testPayload := make([]byte, testPayloadLen)
|
||||||
rand.Read(testPayload)
|
rand.Read(testPayload)
|
||||||
f := &Frame{
|
f := Frame{
|
||||||
1,
|
1,
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
|
|
@ -544,12 +544,13 @@ func BenchmarkRecvDataFromRemote(b *testing.B) {
|
||||||
var sessionKey [32]byte
|
var sessionKey [32]byte
|
||||||
rand.Read(sessionKey[:])
|
rand.Read(sessionKey[:])
|
||||||
|
|
||||||
const maxIter = 100_000 // run with -benchtime 100000x to avoid index out of bounds panic
|
const maxIter = 500_000 // run with -benchtime 500000x to avoid index out of bounds panic
|
||||||
for name, ep := range encryptionMethods {
|
for name, ep := range encryptionMethods {
|
||||||
ep := ep
|
ep := ep
|
||||||
b.Run(name, func(b *testing.B) {
|
b.Run(name, func(b *testing.B) {
|
||||||
for seshType, seshConfig := range seshConfigs {
|
for seshType, seshConfig := range seshConfigs {
|
||||||
b.Run(seshType, func(b *testing.B) {
|
b.Run(seshType, func(b *testing.B) {
|
||||||
|
f := f
|
||||||
seshConfig.Obfuscator, _ = MakeObfuscator(ep, sessionKey)
|
seshConfig.Obfuscator, _ = MakeObfuscator(ep, sessionKey)
|
||||||
sesh := MakeSession(0, seshConfig)
|
sesh := MakeSession(0, seshConfig)
|
||||||
|
|
||||||
|
|
@ -561,7 +562,7 @@ func BenchmarkRecvDataFromRemote(b *testing.B) {
|
||||||
binaryFrames := [maxIter][]byte{}
|
binaryFrames := [maxIter][]byte{}
|
||||||
for i := 0; i < maxIter; i++ {
|
for i := 0; i < maxIter; i++ {
|
||||||
obfsBuf := make([]byte, obfsBufLen)
|
obfsBuf := make([]byte, obfsBufLen)
|
||||||
n, _ := sesh.obfuscate(f, obfsBuf, 0)
|
n, _ := sesh.obfuscate(&f, obfsBuf, 0)
|
||||||
binaryFrames[i] = obfsBuf[:n]
|
binaryFrames[i] = obfsBuf[:n]
|
||||||
f.Seq++
|
f.Seq++
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -23,9 +23,8 @@ type Stream struct {
|
||||||
|
|
||||||
session *Session
|
session *Session
|
||||||
|
|
||||||
allocIdempot sync.Once
|
|
||||||
// a buffer (implemented as an asynchronous buffered pipe) to put data we've received from recvFrame but hasn't
|
// a buffer (implemented as an asynchronous buffered pipe) to put data we've received from recvFrame but hasn't
|
||||||
// been read by the consumer through Read or WriteTo. Lazily allocated
|
// been read by the consumer through Read or WriteTo.
|
||||||
recvBuf recvBuffer
|
recvBuf recvBuffer
|
||||||
|
|
||||||
writingM sync.Mutex
|
writingM sync.Mutex
|
||||||
|
|
@ -40,7 +39,7 @@ type Stream struct {
|
||||||
// recvBuffer (implemented by streamBuffer under ordered mode) will not receive out-of-order packets
|
// recvBuffer (implemented by streamBuffer under ordered mode) will not receive out-of-order packets
|
||||||
// so it won't have to use its priority queue to sort it.
|
// so it won't have to use its priority queue to sort it.
|
||||||
// This is not used in unordered connection mode
|
// This is not used in unordered connection mode
|
||||||
assignedConnId uint32
|
assignedConn net.Conn
|
||||||
|
|
||||||
readFromTimeout time.Duration
|
readFromTimeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
@ -56,25 +55,20 @@ func makeStream(sesh *Session, id uint32) *Stream {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if sesh.Unordered {
|
||||||
|
stream.recvBuf = NewDatagramBufferedPipe()
|
||||||
|
} else {
|
||||||
|
stream.recvBuf = NewStreamBuffer()
|
||||||
|
}
|
||||||
|
|
||||||
return stream
|
return stream
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Stream) isClosed() bool { return atomic.LoadUint32(&s.closed) == 1 }
|
func (s *Stream) isClosed() bool { return atomic.LoadUint32(&s.closed) == 1 }
|
||||||
|
|
||||||
func (s *Stream) getRecvBuf() recvBuffer {
|
|
||||||
s.allocIdempot.Do(func() {
|
|
||||||
if s.session.Unordered {
|
|
||||||
s.recvBuf = NewDatagramBufferedPipe()
|
|
||||||
} else {
|
|
||||||
s.recvBuf = NewStreamBuffer()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
return s.recvBuf
|
|
||||||
}
|
|
||||||
|
|
||||||
// receive a readily deobfuscated Frame so its payload can later be Read
|
// receive a readily deobfuscated Frame so its payload can later be Read
|
||||||
func (s *Stream) recvFrame(frame *Frame) error {
|
func (s *Stream) recvFrame(frame *Frame) error {
|
||||||
toBeClosed, err := s.getRecvBuf().Write(frame)
|
toBeClosed, err := s.recvBuf.Write(frame)
|
||||||
if toBeClosed {
|
if toBeClosed {
|
||||||
err = s.passiveClose()
|
err = s.passiveClose()
|
||||||
if errors.Is(err, errRepeatStreamClosing) {
|
if errors.Is(err, errRepeatStreamClosing) {
|
||||||
|
|
@ -93,7 +87,7 @@ func (s *Stream) Read(buf []byte) (n int, err error) {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
n, err = s.getRecvBuf().Read(buf)
|
n, err = s.recvBuf.Read(buf)
|
||||||
log.Tracef("%v read from stream %v with err %v", n, s.id, err)
|
log.Tracef("%v read from stream %v with err %v", n, s.id, err)
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
return n, ErrBrokenStream
|
return n, ErrBrokenStream
|
||||||
|
|
@ -104,7 +98,7 @@ func (s *Stream) Read(buf []byte) (n int, err error) {
|
||||||
// WriteTo continuously write data Stream has received into the writer w.
|
// WriteTo continuously write data Stream has received into the writer w.
|
||||||
func (s *Stream) WriteTo(w io.Writer) (int64, error) {
|
func (s *Stream) WriteTo(w io.Writer) (int64, error) {
|
||||||
// will keep writing until the underlying buffer is closed
|
// will keep writing until the underlying buffer is closed
|
||||||
n, err := s.getRecvBuf().WriteTo(w)
|
n, err := s.recvBuf.WriteTo(w)
|
||||||
log.Tracef("%v read from stream %v with err %v", n, s.id, err)
|
log.Tracef("%v read from stream %v with err %v", n, s.id, err)
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
return n, ErrBrokenStream
|
return n, ErrBrokenStream
|
||||||
|
|
@ -119,7 +113,7 @@ func (s *Stream) obfuscateAndSend(buf []byte, payloadOffsetInBuf int) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = s.session.sb.send(buf[:cipherTextLen], &s.assignedConnId)
|
_, err = s.session.sb.send(buf[:cipherTextLen], &s.assignedConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == errBrokenSwitchboard {
|
if err == errBrokenSwitchboard {
|
||||||
s.session.SetTerminalMsg(err.Error())
|
s.session.SetTerminalMsg(err.Error())
|
||||||
|
|
@ -215,8 +209,8 @@ func (s *Stream) Close() error {
|
||||||
func (s *Stream) LocalAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[0] }
|
func (s *Stream) LocalAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[0] }
|
||||||
func (s *Stream) RemoteAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[1] }
|
func (s *Stream) RemoteAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[1] }
|
||||||
|
|
||||||
func (s *Stream) SetWriteToTimeout(d time.Duration) { s.getRecvBuf().SetWriteToTimeout(d) }
|
func (s *Stream) SetWriteToTimeout(d time.Duration) { s.recvBuf.SetWriteToTimeout(d) }
|
||||||
func (s *Stream) SetReadDeadline(t time.Time) error { s.getRecvBuf().SetReadDeadline(t); return nil }
|
func (s *Stream) SetReadDeadline(t time.Time) error { s.recvBuf.SetReadDeadline(t); return nil }
|
||||||
func (s *Stream) SetReadFromTimeout(d time.Duration) { s.readFromTimeout = d }
|
func (s *Stream) SetReadFromTimeout(d time.Duration) { s.readFromTimeout = d }
|
||||||
|
|
||||||
var errNotImplemented = errors.New("Not implemented")
|
var errNotImplemented = errors.New("Not implemented")
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
package multiplex
|
package multiplex
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"github.com/stretchr/testify/assert"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
@ -13,49 +13,15 @@ func TestPipeRW(t *testing.T) {
|
||||||
pipe := NewStreamBufferedPipe()
|
pipe := NewStreamBufferedPipe()
|
||||||
b := []byte{0x01, 0x02, 0x03}
|
b := []byte{0x01, 0x02, 0x03}
|
||||||
n, err := pipe.Write(b)
|
n, err := pipe.Write(b)
|
||||||
if n != len(b) {
|
assert.NoError(t, err, "simple write")
|
||||||
t.Error(
|
assert.Equal(t, len(b), n, "number of bytes written")
|
||||||
"For", "number of bytes written",
|
|
||||||
"expecting", len(b),
|
|
||||||
"got", n,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
t.Error(
|
|
||||||
"For", "simple write",
|
|
||||||
"expecting", "nil error",
|
|
||||||
"got", err,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
b2 := make([]byte, len(b))
|
b2 := make([]byte, len(b))
|
||||||
n, err = pipe.Read(b2)
|
n, err = pipe.Read(b2)
|
||||||
if n != len(b) {
|
assert.NoError(t, err, "simple read")
|
||||||
t.Error(
|
assert.Equal(t, len(b), n, "number of bytes read")
|
||||||
"For", "number of bytes read",
|
|
||||||
"expecting", len(b),
|
|
||||||
"got", n,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
t.Error(
|
|
||||||
"For", "simple read",
|
|
||||||
"expecting", "nil error",
|
|
||||||
"got", err,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !bytes.Equal(b, b2) {
|
|
||||||
t.Error(
|
|
||||||
"For", "simple read",
|
|
||||||
"expecting", b,
|
|
||||||
"got", b2,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
assert.Equal(t, b, b2)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestReadBlock(t *testing.T) {
|
func TestReadBlock(t *testing.T) {
|
||||||
|
|
@ -67,30 +33,10 @@ func TestReadBlock(t *testing.T) {
|
||||||
}()
|
}()
|
||||||
b2 := make([]byte, len(b))
|
b2 := make([]byte, len(b))
|
||||||
n, err := pipe.Read(b2)
|
n, err := pipe.Read(b2)
|
||||||
if n != len(b) {
|
assert.NoError(t, err, "blocked read")
|
||||||
t.Error(
|
assert.Equal(t, len(b), n, "number of bytes read after block")
|
||||||
"For", "number of bytes read after block",
|
|
||||||
"expecting", len(b),
|
assert.Equal(t, b, b2)
|
||||||
"got", n,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
t.Error(
|
|
||||||
"For", "blocked read",
|
|
||||||
"expecting", "nil error",
|
|
||||||
"got", err,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !bytes.Equal(b, b2) {
|
|
||||||
t.Error(
|
|
||||||
"For", "blocked read",
|
|
||||||
"expecting", b,
|
|
||||||
"got", b2,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPartialRead(t *testing.T) {
|
func TestPartialRead(t *testing.T) {
|
||||||
|
|
@ -99,54 +45,17 @@ func TestPartialRead(t *testing.T) {
|
||||||
pipe.Write(b)
|
pipe.Write(b)
|
||||||
b1 := make([]byte, 1)
|
b1 := make([]byte, 1)
|
||||||
n, err := pipe.Read(b1)
|
n, err := pipe.Read(b1)
|
||||||
if n != len(b1) {
|
assert.NoError(t, err, "partial read of 1")
|
||||||
t.Error(
|
assert.Equal(t, len(b1), n, "number of bytes in partial read of 1")
|
||||||
"For", "number of bytes in partial read of 1",
|
|
||||||
"expecting", len(b1),
|
assert.Equal(t, b[0], b1[0])
|
||||||
"got", n,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
t.Error(
|
|
||||||
"For", "partial read of 1",
|
|
||||||
"expecting", "nil error",
|
|
||||||
"got", err,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if b1[0] != b[0] {
|
|
||||||
t.Error(
|
|
||||||
"For", "partial read of 1",
|
|
||||||
"expecting", b[0],
|
|
||||||
"got", b1[0],
|
|
||||||
)
|
|
||||||
}
|
|
||||||
b2 := make([]byte, 2)
|
b2 := make([]byte, 2)
|
||||||
n, err = pipe.Read(b2)
|
n, err = pipe.Read(b2)
|
||||||
if n != len(b2) {
|
assert.NoError(t, err, "partial read of 2")
|
||||||
t.Error(
|
assert.Equal(t, len(b2), n, "number of bytes in partial read of 2")
|
||||||
"For", "number of bytes in partial read of 2",
|
|
||||||
"expecting", len(b2),
|
assert.Equal(t, b[1:], b2)
|
||||||
"got", n,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
t.Error(
|
|
||||||
"For", "partial read of 2",
|
|
||||||
"expecting", "nil error",
|
|
||||||
"got", err,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !bytes.Equal(b[1:], b2) {
|
|
||||||
t.Error(
|
|
||||||
"For", "partial read of 2",
|
|
||||||
"expecting", b[1:],
|
|
||||||
"got", b2,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestReadAfterClose(t *testing.T) {
|
func TestReadAfterClose(t *testing.T) {
|
||||||
|
|
@ -156,29 +65,10 @@ func TestReadAfterClose(t *testing.T) {
|
||||||
b2 := make([]byte, len(b))
|
b2 := make([]byte, len(b))
|
||||||
pipe.Close()
|
pipe.Close()
|
||||||
n, err := pipe.Read(b2)
|
n, err := pipe.Read(b2)
|
||||||
if n != len(b) {
|
assert.NoError(t, err, "simple read")
|
||||||
t.Error(
|
assert.Equal(t, len(b), n, "number of bytes read")
|
||||||
"For", "number of bytes read",
|
|
||||||
"expecting", len(b),
|
assert.Equal(t, b, b2)
|
||||||
"got", n,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
t.Error(
|
|
||||||
"For", "simple read",
|
|
||||||
"expecting", "nil error",
|
|
||||||
"got", err,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !bytes.Equal(b, b2) {
|
|
||||||
t.Error(
|
|
||||||
"For", "simple read",
|
|
||||||
"expecting", b,
|
|
||||||
"got", b2,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkBufferedPipe_RW(b *testing.B) {
|
func BenchmarkBufferedPipe_RW(b *testing.B) {
|
||||||
|
|
|
||||||
|
|
@ -10,9 +10,11 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type switchboardStrategy int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
FIXED_CONN_MAPPING switchboardStrategy = iota
|
fixedConnMapping switchboardStrategy = iota
|
||||||
UNIFORM_SPREAD
|
uniformSpread
|
||||||
)
|
)
|
||||||
|
|
||||||
// switchboard represents the connection pool. It is responsible for managing
|
// switchboard represents the connection pool. It is responsible for managing
|
||||||
|
|
@ -28,10 +30,8 @@ type switchboard struct {
|
||||||
valve Valve
|
valve Valve
|
||||||
strategy switchboardStrategy
|
strategy switchboardStrategy
|
||||||
|
|
||||||
// map of connId to net.Conn
|
|
||||||
conns sync.Map
|
conns sync.Map
|
||||||
numConns uint32
|
connsCount uint32
|
||||||
nextConnId uint32
|
|
||||||
randPool sync.Pool
|
randPool sync.Pool
|
||||||
|
|
||||||
broken uint32
|
broken uint32
|
||||||
|
|
@ -41,15 +41,14 @@ func makeSwitchboard(sesh *Session) *switchboard {
|
||||||
var strategy switchboardStrategy
|
var strategy switchboardStrategy
|
||||||
if sesh.Unordered {
|
if sesh.Unordered {
|
||||||
log.Debug("Connection is unordered")
|
log.Debug("Connection is unordered")
|
||||||
strategy = UNIFORM_SPREAD
|
strategy = uniformSpread
|
||||||
} else {
|
} else {
|
||||||
strategy = FIXED_CONN_MAPPING
|
strategy = fixedConnMapping
|
||||||
}
|
}
|
||||||
sb := &switchboard{
|
sb := &switchboard{
|
||||||
session: sesh,
|
session: sesh,
|
||||||
strategy: strategy,
|
strategy: strategy,
|
||||||
valve: sesh.Valve,
|
valve: sesh.Valve,
|
||||||
nextConnId: 1,
|
|
||||||
randPool: sync.Pool{New: func() interface{} {
|
randPool: sync.Pool{New: func() interface{} {
|
||||||
return rand.New(rand.NewSource(int64(time.Now().Nanosecond())))
|
return rand.New(rand.NewSource(int64(time.Now().Nanosecond())))
|
||||||
}},
|
}},
|
||||||
|
|
@ -59,88 +58,85 @@ func makeSwitchboard(sesh *Session) *switchboard {
|
||||||
|
|
||||||
var errBrokenSwitchboard = errors.New("the switchboard is broken")
|
var errBrokenSwitchboard = errors.New("the switchboard is broken")
|
||||||
|
|
||||||
func (sb *switchboard) connsCount() int {
|
|
||||||
return int(atomic.LoadUint32(&sb.numConns))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (sb *switchboard) addConn(conn net.Conn) {
|
func (sb *switchboard) addConn(conn net.Conn) {
|
||||||
connId := atomic.AddUint32(&sb.nextConnId, 1) - 1
|
atomic.AddUint32(&sb.connsCount, 1)
|
||||||
atomic.AddUint32(&sb.numConns, 1)
|
sb.conns.Store(conn, conn)
|
||||||
sb.conns.Store(connId, conn)
|
go sb.deplex(conn)
|
||||||
go sb.deplex(connId, conn)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// a pointer to connId is passed here so that the switchboard can reassign it if that connId isn't usable
|
// a pointer to assignedConn is passed here so that the switchboard can reassign it if that conn isn't usable
|
||||||
func (sb *switchboard) send(data []byte, connId *uint32) (n int, err error) {
|
func (sb *switchboard) send(data []byte, assignedConn *net.Conn) (n int, err error) {
|
||||||
sb.valve.txWait(len(data))
|
sb.valve.txWait(len(data))
|
||||||
if atomic.LoadUint32(&sb.broken) == 1 || sb.connsCount() == 0 {
|
if atomic.LoadUint32(&sb.broken) == 1 {
|
||||||
return 0, errBrokenSwitchboard
|
return 0, errBrokenSwitchboard
|
||||||
}
|
}
|
||||||
|
|
||||||
var conn net.Conn
|
var conn net.Conn
|
||||||
switch sb.strategy {
|
switch sb.strategy {
|
||||||
case UNIFORM_SPREAD:
|
case uniformSpread:
|
||||||
_, conn, err = sb.pickRandConn()
|
conn, err = sb.pickRandConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, errBrokenSwitchboard
|
return 0, errBrokenSwitchboard
|
||||||
}
|
}
|
||||||
case FIXED_CONN_MAPPING:
|
n, err = conn.Write(data)
|
||||||
connI, ok := sb.conns.Load(*connId)
|
|
||||||
if ok {
|
|
||||||
conn = connI.(net.Conn)
|
|
||||||
} else {
|
|
||||||
var newConnId uint32
|
|
||||||
newConnId, conn, err = sb.pickRandConn()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, errBrokenSwitchboard
|
sb.session.SetTerminalMsg("failed to send to remote " + err.Error())
|
||||||
|
sb.session.passiveClose()
|
||||||
|
return n, err
|
||||||
}
|
}
|
||||||
*connId = newConnId
|
case fixedConnMapping:
|
||||||
|
conn = *assignedConn
|
||||||
|
if conn == nil {
|
||||||
|
conn, err = sb.pickRandConn()
|
||||||
|
if err != nil {
|
||||||
|
sb.session.SetTerminalMsg("failed to pick a connection " + err.Error())
|
||||||
|
sb.session.passiveClose()
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
*assignedConn = conn
|
||||||
|
}
|
||||||
|
n, err = conn.Write(data)
|
||||||
|
if err != nil {
|
||||||
|
sb.session.SetTerminalMsg("failed to send to remote " + err.Error())
|
||||||
|
sb.session.passiveClose()
|
||||||
|
return n, err
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return 0, errors.New("unsupported traffic distribution strategy")
|
return 0, errors.New("unsupported traffic distribution strategy")
|
||||||
}
|
}
|
||||||
|
|
||||||
n, err = conn.Write(data)
|
|
||||||
if err != nil {
|
|
||||||
sb.conns.Delete(*connId)
|
|
||||||
sb.session.SetTerminalMsg("failed to write to remote " + err.Error())
|
|
||||||
sb.session.passiveClose()
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
sb.valve.AddTx(int64(n))
|
sb.valve.AddTx(int64(n))
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// returns a random connId
|
// returns a random connId
|
||||||
func (sb *switchboard) pickRandConn() (uint32, net.Conn, error) {
|
func (sb *switchboard) pickRandConn() (net.Conn, error) {
|
||||||
connCount := sb.connsCount()
|
if atomic.LoadUint32(&sb.broken) == 1 {
|
||||||
if atomic.LoadUint32(&sb.broken) == 1 || connCount == 0 {
|
return nil, errBrokenSwitchboard
|
||||||
return 0, nil, errBrokenSwitchboard
|
}
|
||||||
|
|
||||||
|
connsCount := atomic.LoadUint32(&sb.connsCount)
|
||||||
|
if connsCount == 0 {
|
||||||
|
return nil, errBrokenSwitchboard
|
||||||
}
|
}
|
||||||
|
|
||||||
// there is no guarantee that sb.conns still has the same amount of entries
|
|
||||||
// between the count loop and the pick loop
|
|
||||||
// so if the r > len(sb.conns) at the point of range call, the last visited element is picked
|
|
||||||
var id uint32
|
|
||||||
var conn net.Conn
|
|
||||||
randReader := sb.randPool.Get().(*rand.Rand)
|
randReader := sb.randPool.Get().(*rand.Rand)
|
||||||
r := randReader.Intn(connCount)
|
|
||||||
|
r := randReader.Intn(int(connsCount))
|
||||||
sb.randPool.Put(randReader)
|
sb.randPool.Put(randReader)
|
||||||
|
|
||||||
var c int
|
var c int
|
||||||
sb.conns.Range(func(connIdI, connI interface{}) bool {
|
var ret net.Conn
|
||||||
|
sb.conns.Range(func(_, conn interface{}) bool {
|
||||||
if r == c {
|
if r == c {
|
||||||
id = connIdI.(uint32)
|
ret = conn.(net.Conn)
|
||||||
conn = connI.(net.Conn)
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
c++
|
c++
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
// if len(sb.conns) is 0
|
|
||||||
if conn == nil {
|
return ret, nil
|
||||||
return 0, nil, errBrokenSwitchboard
|
|
||||||
}
|
|
||||||
return id, conn, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// actively triggered by session.Close()
|
// actively triggered by session.Close()
|
||||||
|
|
@ -148,26 +144,24 @@ func (sb *switchboard) closeAll() {
|
||||||
if !atomic.CompareAndSwapUint32(&sb.broken, 0, 1) {
|
if !atomic.CompareAndSwapUint32(&sb.broken, 0, 1) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
sb.conns.Range(func(key, connI interface{}) bool {
|
sb.conns.Range(func(_, conn interface{}) bool {
|
||||||
conn := connI.(net.Conn)
|
conn.(net.Conn).Close()
|
||||||
conn.Close()
|
sb.conns.Delete(conn)
|
||||||
sb.conns.Delete(key)
|
atomic.AddUint32(&sb.connsCount, ^uint32(0))
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// deplex function costantly reads from a TCP connection
|
// deplex function costantly reads from a TCP connection
|
||||||
func (sb *switchboard) deplex(connId uint32, conn net.Conn) {
|
func (sb *switchboard) deplex(conn net.Conn) {
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
buf := make([]byte, sb.session.ConnReceiveBufferSize)
|
buf := make([]byte, sb.session.connReceiveBufferSize)
|
||||||
for {
|
for {
|
||||||
n, err := conn.Read(buf)
|
n, err := conn.Read(buf)
|
||||||
sb.valve.rxWait(n)
|
sb.valve.rxWait(n)
|
||||||
sb.valve.AddRx(int64(n))
|
sb.valve.AddRx(int64(n))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("a connection for session %v has closed: %v", sb.session.id, err)
|
log.Debugf("a connection for session %v has closed: %v", sb.session.id, err)
|
||||||
sb.conns.Delete(connId)
|
|
||||||
atomic.AddUint32(&sb.numConns, ^uint32(0))
|
|
||||||
sb.session.SetTerminalMsg("a connection has dropped unexpectedly")
|
sb.session.SetTerminalMsg("a connection has dropped unexpectedly")
|
||||||
sb.session.passiveClose()
|
sb.session.passiveClose()
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
@ -14,14 +15,14 @@ func TestSwitchboard_Send(t *testing.T) {
|
||||||
sesh := MakeSession(0, seshConfig)
|
sesh := MakeSession(0, seshConfig)
|
||||||
hole0 := connutil.Discard()
|
hole0 := connutil.Discard()
|
||||||
sesh.sb.addConn(hole0)
|
sesh.sb.addConn(hole0)
|
||||||
connId, _, err := sesh.sb.pickRandConn()
|
conn, err := sesh.sb.pickRandConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("failed to get a random conn", err)
|
t.Error("failed to get a random conn", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
data := make([]byte, 1000)
|
data := make([]byte, 1000)
|
||||||
rand.Read(data)
|
rand.Read(data)
|
||||||
_, err = sesh.sb.send(data, &connId)
|
_, err = sesh.sb.send(data, &conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
return
|
return
|
||||||
|
|
@ -29,23 +30,23 @@ func TestSwitchboard_Send(t *testing.T) {
|
||||||
|
|
||||||
hole1 := connutil.Discard()
|
hole1 := connutil.Discard()
|
||||||
sesh.sb.addConn(hole1)
|
sesh.sb.addConn(hole1)
|
||||||
connId, _, err = sesh.sb.pickRandConn()
|
conn, err = sesh.sb.pickRandConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("failed to get a random conn", err)
|
t.Error("failed to get a random conn", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
_, err = sesh.sb.send(data, &connId)
|
_, err = sesh.sb.send(data, &conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
connId, _, err = sesh.sb.pickRandConn()
|
conn, err = sesh.sb.pickRandConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("failed to get a random conn", err)
|
t.Error("failed to get a random conn", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
_, err = sesh.sb.send(data, &connId)
|
_, err = sesh.sb.send(data, &conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
return
|
return
|
||||||
|
|
@ -71,7 +72,7 @@ func BenchmarkSwitchboard_Send(b *testing.B) {
|
||||||
seshConfig := SessionConfig{}
|
seshConfig := SessionConfig{}
|
||||||
sesh := MakeSession(0, seshConfig)
|
sesh := MakeSession(0, seshConfig)
|
||||||
sesh.sb.addConn(hole)
|
sesh.sb.addConn(hole)
|
||||||
connId, _, err := sesh.sb.pickRandConn()
|
conn, err := sesh.sb.pickRandConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Error("failed to get a random conn", err)
|
b.Error("failed to get a random conn", err)
|
||||||
return
|
return
|
||||||
|
|
@ -81,7 +82,7 @@ func BenchmarkSwitchboard_Send(b *testing.B) {
|
||||||
b.SetBytes(int64(len(data)))
|
b.SetBytes(int64(len(data)))
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
sesh.sb.send(data, &connId)
|
sesh.sb.send(data, &conn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -92,7 +93,7 @@ func TestSwitchboard_TxCredit(t *testing.T) {
|
||||||
sesh := MakeSession(0, seshConfig)
|
sesh := MakeSession(0, seshConfig)
|
||||||
hole := connutil.Discard()
|
hole := connutil.Discard()
|
||||||
sesh.sb.addConn(hole)
|
sesh.sb.addConn(hole)
|
||||||
connId, _, err := sesh.sb.pickRandConn()
|
conn, err := sesh.sb.pickRandConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("failed to get a random conn", err)
|
t.Error("failed to get a random conn", err)
|
||||||
return
|
return
|
||||||
|
|
@ -100,10 +101,10 @@ func TestSwitchboard_TxCredit(t *testing.T) {
|
||||||
data := make([]byte, 1000)
|
data := make([]byte, 1000)
|
||||||
rand.Read(data)
|
rand.Read(data)
|
||||||
|
|
||||||
t.Run("FIXED CONN MAPPING", func(t *testing.T) {
|
t.Run("fixed conn mapping", func(t *testing.T) {
|
||||||
*sesh.sb.valve.(*LimitedValve).tx = 0
|
*sesh.sb.valve.(*LimitedValve).tx = 0
|
||||||
sesh.sb.strategy = FIXED_CONN_MAPPING
|
sesh.sb.strategy = fixedConnMapping
|
||||||
n, err := sesh.sb.send(data[:10], &connId)
|
n, err := sesh.sb.send(data[:10], &conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
return
|
return
|
||||||
|
|
@ -116,10 +117,10 @@ func TestSwitchboard_TxCredit(t *testing.T) {
|
||||||
t.Error("tx credit didn't increase by 10")
|
t.Error("tx credit didn't increase by 10")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run("UNIFORM", func(t *testing.T) {
|
t.Run("uniform spread", func(t *testing.T) {
|
||||||
*sesh.sb.valve.(*LimitedValve).tx = 0
|
*sesh.sb.valve.(*LimitedValve).tx = 0
|
||||||
sesh.sb.strategy = UNIFORM_SPREAD
|
sesh.sb.strategy = uniformSpread
|
||||||
n, err := sesh.sb.send(data[:10], &connId)
|
n, err := sesh.sb.send(data[:10], &conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
return
|
return
|
||||||
|
|
@ -173,13 +174,13 @@ func TestSwitchboard_ConnsCount(t *testing.T) {
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
if sesh.sb.connsCount() != 1000 {
|
if atomic.LoadUint32(&sesh.sb.connsCount) != 1000 {
|
||||||
t.Error("connsCount incorrect")
|
t.Error("connsCount incorrect")
|
||||||
}
|
}
|
||||||
|
|
||||||
sesh.sb.closeAll()
|
sesh.sb.closeAll()
|
||||||
|
|
||||||
assert.Eventuallyf(t, func() bool {
|
assert.Eventuallyf(t, func() bool {
|
||||||
return sesh.sb.connsCount() == 0
|
return atomic.LoadUint32(&sesh.sb.connsCount) == 0
|
||||||
}, time.Second, 10*time.Millisecond, "connsCount incorrect: %v", sesh.sb.connsCount())
|
}, time.Second, 10*time.Millisecond, "connsCount incorrect: %v", atomic.LoadUint32(&sesh.sb.connsCount))
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ swagger: '2.0'
|
||||||
info:
|
info:
|
||||||
description: |
|
description: |
|
||||||
This is the API of Cloak server
|
This is the API of Cloak server
|
||||||
version: 1.0.0
|
version: 0.0.2
|
||||||
title: Cloak Server
|
title: Cloak Server
|
||||||
contact:
|
contact:
|
||||||
email: cbeuw.andy@gmail.com
|
email: cbeuw.andy@gmail.com
|
||||||
|
|
@ -12,8 +12,6 @@ info:
|
||||||
# host: petstore.swagger.io
|
# host: petstore.swagger.io
|
||||||
# basePath: /v2
|
# basePath: /v2
|
||||||
tags:
|
tags:
|
||||||
- name: admin
|
|
||||||
description: Endpoints used by the host administrators
|
|
||||||
- name: users
|
- name: users
|
||||||
description: Operations related to user controls by admin
|
description: Operations related to user controls by admin
|
||||||
# schemes:
|
# schemes:
|
||||||
|
|
@ -22,7 +20,6 @@ paths:
|
||||||
/admin/users:
|
/admin/users:
|
||||||
get:
|
get:
|
||||||
tags:
|
tags:
|
||||||
- admin
|
|
||||||
- users
|
- users
|
||||||
summary: Show all users
|
summary: Show all users
|
||||||
description: Returns an array of all UserInfo
|
description: Returns an array of all UserInfo
|
||||||
|
|
@ -41,7 +38,6 @@ paths:
|
||||||
/admin/users/{UID}:
|
/admin/users/{UID}:
|
||||||
get:
|
get:
|
||||||
tags:
|
tags:
|
||||||
- admin
|
|
||||||
- users
|
- users
|
||||||
summary: Show userinfo by UID
|
summary: Show userinfo by UID
|
||||||
description: Returns a UserInfo object
|
description: Returns a UserInfo object
|
||||||
|
|
@ -68,7 +64,6 @@ paths:
|
||||||
description: internal error
|
description: internal error
|
||||||
post:
|
post:
|
||||||
tags:
|
tags:
|
||||||
- admin
|
|
||||||
- users
|
- users
|
||||||
summary: Updates the userinfo of the specified user, if the user does not exist, then a new user is created
|
summary: Updates the userinfo of the specified user, if the user does not exist, then a new user is created
|
||||||
operationId: writeUserInfo
|
operationId: writeUserInfo
|
||||||
|
|
@ -100,7 +95,6 @@ paths:
|
||||||
description: internal error
|
description: internal error
|
||||||
delete:
|
delete:
|
||||||
tags:
|
tags:
|
||||||
- admin
|
|
||||||
- users
|
- users
|
||||||
summary: Deletes a user
|
summary: Deletes a user
|
||||||
operationId: deleteUser
|
operationId: deleteUser
|
||||||
|
|
|
||||||
|
|
@ -46,6 +46,36 @@ func TestWriteUserInfoHlr(t *testing.T) {
|
||||||
assert.Equalf(t, http.StatusCreated, rr.Code, "response body: %v", rr.Body)
|
assert.Equalf(t, http.StatusCreated, rr.Code, "response body: %v", rr.Body)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("partial update", func(t *testing.T) {
|
||||||
|
req, err := http.NewRequest("POST", "/admin/users/"+mockUIDb64, bytes.NewBuffer(marshalled))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
assert.Equal(t, http.StatusCreated, rr.Code)
|
||||||
|
|
||||||
|
partialUserInfo := UserInfo{
|
||||||
|
UID: mockUID,
|
||||||
|
SessionsCap: JustInt32(10),
|
||||||
|
}
|
||||||
|
partialMarshalled, _ := json.Marshal(partialUserInfo)
|
||||||
|
req, err = http.NewRequest("POST", "/admin/users/"+mockUIDb64, bytes.NewBuffer(partialMarshalled))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
assert.Equal(t, http.StatusCreated, rr.Code)
|
||||||
|
|
||||||
|
req, err = http.NewRequest("GET", "/admin/users/"+mockUIDb64, nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
assert.Equal(t, http.StatusCreated, rr.Code)
|
||||||
|
var got UserInfo
|
||||||
|
err = json.Unmarshal(rr.Body.Bytes(), &got)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
expected := mockUserInfo
|
||||||
|
expected.SessionsCap = partialUserInfo.SessionsCap
|
||||||
|
assert.EqualValues(t, expected, got)
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("empty parameter", func(t *testing.T) {
|
t.Run("empty parameter", func(t *testing.T) {
|
||||||
req, err := http.NewRequest("POST", "/admin/users/", bytes.NewBuffer(marshalled))
|
req, err := http.NewRequest("POST", "/admin/users/", bytes.NewBuffer(marshalled))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -127,6 +127,7 @@ func (manager *localManager) UploadStatus(uploads []StatusUpdate) ([]StatusRespo
|
||||||
"User no longer exists",
|
"User no longer exists",
|
||||||
}
|
}
|
||||||
responses = append(responses, resp)
|
responses = append(responses, resp)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
oldUp := int64(u64(bucket.Get([]byte("UpCredit"))))
|
oldUp := int64(u64(bucket.Get([]byte("UpCredit"))))
|
||||||
|
|
@ -179,12 +180,12 @@ func (manager *localManager) ListAllUsers() (infos []UserInfo, err error) {
|
||||||
err = tx.ForEach(func(UID []byte, bucket *bolt.Bucket) error {
|
err = tx.ForEach(func(UID []byte, bucket *bolt.Bucket) error {
|
||||||
var uinfo UserInfo
|
var uinfo UserInfo
|
||||||
uinfo.UID = UID
|
uinfo.UID = UID
|
||||||
uinfo.SessionsCap = int32(u32(bucket.Get([]byte("SessionsCap"))))
|
uinfo.SessionsCap = JustInt32(int32(u32(bucket.Get([]byte("SessionsCap")))))
|
||||||
uinfo.UpRate = int64(u64(bucket.Get([]byte("UpRate"))))
|
uinfo.UpRate = JustInt64(int64(u64(bucket.Get([]byte("UpRate")))))
|
||||||
uinfo.DownRate = int64(u64(bucket.Get([]byte("DownRate"))))
|
uinfo.DownRate = JustInt64(int64(u64(bucket.Get([]byte("DownRate")))))
|
||||||
uinfo.UpCredit = int64(u64(bucket.Get([]byte("UpCredit"))))
|
uinfo.UpCredit = JustInt64(int64(u64(bucket.Get([]byte("UpCredit")))))
|
||||||
uinfo.DownCredit = int64(u64(bucket.Get([]byte("DownCredit"))))
|
uinfo.DownCredit = JustInt64(int64(u64(bucket.Get([]byte("DownCredit")))))
|
||||||
uinfo.ExpiryTime = int64(u64(bucket.Get([]byte("ExpiryTime"))))
|
uinfo.ExpiryTime = JustInt64(int64(u64(bucket.Get([]byte("ExpiryTime")))))
|
||||||
infos = append(infos, uinfo)
|
infos = append(infos, uinfo)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
@ -200,41 +201,53 @@ func (manager *localManager) GetUserInfo(UID []byte) (uinfo UserInfo, err error)
|
||||||
return ErrUserNotFound
|
return ErrUserNotFound
|
||||||
}
|
}
|
||||||
uinfo.UID = UID
|
uinfo.UID = UID
|
||||||
uinfo.SessionsCap = int32(u32(bucket.Get([]byte("SessionsCap"))))
|
uinfo.SessionsCap = JustInt32(int32(u32(bucket.Get([]byte("SessionsCap")))))
|
||||||
uinfo.UpRate = int64(u64(bucket.Get([]byte("UpRate"))))
|
uinfo.UpRate = JustInt64(int64(u64(bucket.Get([]byte("UpRate")))))
|
||||||
uinfo.DownRate = int64(u64(bucket.Get([]byte("DownRate"))))
|
uinfo.DownRate = JustInt64(int64(u64(bucket.Get([]byte("DownRate")))))
|
||||||
uinfo.UpCredit = int64(u64(bucket.Get([]byte("UpCredit"))))
|
uinfo.UpCredit = JustInt64(int64(u64(bucket.Get([]byte("UpCredit")))))
|
||||||
uinfo.DownCredit = int64(u64(bucket.Get([]byte("DownCredit"))))
|
uinfo.DownCredit = JustInt64(int64(u64(bucket.Get([]byte("DownCredit")))))
|
||||||
uinfo.ExpiryTime = int64(u64(bucket.Get([]byte("ExpiryTime"))))
|
uinfo.ExpiryTime = JustInt64(int64(u64(bucket.Get([]byte("ExpiryTime")))))
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (manager *localManager) WriteUserInfo(uinfo UserInfo) (err error) {
|
func (manager *localManager) WriteUserInfo(u UserInfo) (err error) {
|
||||||
err = manager.db.Update(func(tx *bolt.Tx) error {
|
err = manager.db.Update(func(tx *bolt.Tx) error {
|
||||||
bucket, err := tx.CreateBucketIfNotExists(uinfo.UID)
|
bucket, err := tx.CreateBucketIfNotExists(u.UID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err = bucket.Put([]byte("SessionsCap"), i32ToB(int32(uinfo.SessionsCap))); err != nil {
|
if u.SessionsCap != nil {
|
||||||
|
if err = bucket.Put([]byte("SessionsCap"), i32ToB(*u.SessionsCap)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err = bucket.Put([]byte("UpRate"), i64ToB(uinfo.UpRate)); err != nil {
|
}
|
||||||
|
if u.UpRate != nil {
|
||||||
|
if err = bucket.Put([]byte("UpRate"), i64ToB(*u.UpRate)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err = bucket.Put([]byte("DownRate"), i64ToB(uinfo.DownRate)); err != nil {
|
}
|
||||||
|
if u.DownRate != nil {
|
||||||
|
if err = bucket.Put([]byte("DownRate"), i64ToB(*u.DownRate)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err = bucket.Put([]byte("UpCredit"), i64ToB(uinfo.UpCredit)); err != nil {
|
}
|
||||||
|
if u.UpCredit != nil {
|
||||||
|
if err = bucket.Put([]byte("UpCredit"), i64ToB(*u.UpCredit)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err = bucket.Put([]byte("DownCredit"), i64ToB(uinfo.DownCredit)); err != nil {
|
}
|
||||||
|
if u.DownCredit != nil {
|
||||||
|
if err = bucket.Put([]byte("DownCredit"), i64ToB(*u.DownCredit)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err = bucket.Put([]byte("ExpiryTime"), i64ToB(uinfo.ExpiryTime)); err != nil {
|
}
|
||||||
|
if u.ExpiryTime != nil {
|
||||||
|
if err = bucket.Put([]byte("ExpiryTime"), i64ToB(*u.ExpiryTime)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ package usermanager
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"github.com/cbeuw/Cloak/internal/common"
|
"github.com/cbeuw/Cloak/internal/common"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"os"
|
"os"
|
||||||
|
|
@ -17,12 +18,12 @@ var mockUID = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
|
||||||
var mockWorldState = common.WorldOfTime(time.Unix(1, 0))
|
var mockWorldState = common.WorldOfTime(time.Unix(1, 0))
|
||||||
var mockUserInfo = UserInfo{
|
var mockUserInfo = UserInfo{
|
||||||
UID: mockUID,
|
UID: mockUID,
|
||||||
SessionsCap: 0,
|
SessionsCap: JustInt32(10),
|
||||||
UpRate: 0,
|
UpRate: JustInt64(100),
|
||||||
DownRate: 0,
|
DownRate: JustInt64(1000),
|
||||||
UpCredit: 0,
|
UpCredit: JustInt64(10000),
|
||||||
DownCredit: 0,
|
DownCredit: JustInt64(100000),
|
||||||
ExpiryTime: 100,
|
ExpiryTime: JustInt64(1000000),
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeManager(t *testing.T) (mgr *localManager, cleaner func()) {
|
func makeManager(t *testing.T) (mgr *localManager, cleaner func()) {
|
||||||
|
|
@ -43,6 +44,23 @@ func TestLocalManager_WriteUserInfo(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
got, err := mgr.GetUserInfo(mockUID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, mockUserInfo, got)
|
||||||
|
|
||||||
|
/* Partial update */
|
||||||
|
err = mgr.WriteUserInfo(UserInfo{
|
||||||
|
UID: mockUID,
|
||||||
|
SessionsCap: JustInt32(*mockUserInfo.SessionsCap + 1),
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
expected := mockUserInfo
|
||||||
|
expected.SessionsCap = JustInt32(*mockUserInfo.SessionsCap + 1)
|
||||||
|
got, err = mgr.GetUserInfo(mockUID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, expected, got)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLocalManager_GetUserInfo(t *testing.T) {
|
func TestLocalManager_GetUserInfo(t *testing.T) {
|
||||||
|
|
@ -63,7 +81,7 @@ func TestLocalManager_GetUserInfo(t *testing.T) {
|
||||||
t.Run("update a field", func(t *testing.T) {
|
t.Run("update a field", func(t *testing.T) {
|
||||||
_ = mgr.WriteUserInfo(mockUserInfo)
|
_ = mgr.WriteUserInfo(mockUserInfo)
|
||||||
updatedUserInfo := mockUserInfo
|
updatedUserInfo := mockUserInfo
|
||||||
updatedUserInfo.SessionsCap = mockUserInfo.SessionsCap + 1
|
updatedUserInfo.SessionsCap = JustInt32(*mockUserInfo.SessionsCap + 1)
|
||||||
|
|
||||||
err := mgr.WriteUserInfo(updatedUserInfo)
|
err := mgr.WriteUserInfo(updatedUserInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -103,15 +121,7 @@ func TestLocalManager_DeleteUser(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var validUserInfo = UserInfo{
|
var validUserInfo = mockUserInfo
|
||||||
UID: mockUID,
|
|
||||||
SessionsCap: 10,
|
|
||||||
UpRate: 100,
|
|
||||||
DownRate: 1000,
|
|
||||||
UpCredit: 10000,
|
|
||||||
DownCredit: 100000,
|
|
||||||
ExpiryTime: 1000000,
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLocalManager_AuthenticateUser(t *testing.T) {
|
func TestLocalManager_AuthenticateUser(t *testing.T) {
|
||||||
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
|
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
|
||||||
|
|
@ -128,7 +138,7 @@ func TestLocalManager_AuthenticateUser(t *testing.T) {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if upRate != validUserInfo.UpRate || downRate != validUserInfo.DownRate {
|
if upRate != *validUserInfo.UpRate || downRate != *validUserInfo.DownRate {
|
||||||
t.Error("wrong up or down rate")
|
t.Error("wrong up or down rate")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
@ -142,7 +152,7 @@ func TestLocalManager_AuthenticateUser(t *testing.T) {
|
||||||
|
|
||||||
t.Run("expired user", func(t *testing.T) {
|
t.Run("expired user", func(t *testing.T) {
|
||||||
expiredUserInfo := validUserInfo
|
expiredUserInfo := validUserInfo
|
||||||
expiredUserInfo.ExpiryTime = mockWorldState.Now().Add(-10 * time.Second).Unix()
|
expiredUserInfo.ExpiryTime = JustInt64(mockWorldState.Now().Add(-10 * time.Second).Unix())
|
||||||
|
|
||||||
_ = mgr.WriteUserInfo(expiredUserInfo)
|
_ = mgr.WriteUserInfo(expiredUserInfo)
|
||||||
|
|
||||||
|
|
@ -154,7 +164,7 @@ func TestLocalManager_AuthenticateUser(t *testing.T) {
|
||||||
|
|
||||||
t.Run("no credit", func(t *testing.T) {
|
t.Run("no credit", func(t *testing.T) {
|
||||||
creditlessUserInfo := validUserInfo
|
creditlessUserInfo := validUserInfo
|
||||||
creditlessUserInfo.UpCredit, creditlessUserInfo.DownCredit = -1, -1
|
creditlessUserInfo.UpCredit, creditlessUserInfo.DownCredit = JustInt64(-1), JustInt64(-1)
|
||||||
|
|
||||||
_ = mgr.WriteUserInfo(creditlessUserInfo)
|
_ = mgr.WriteUserInfo(creditlessUserInfo)
|
||||||
|
|
||||||
|
|
@ -186,7 +196,7 @@ func TestLocalManager_AuthoriseNewSession(t *testing.T) {
|
||||||
|
|
||||||
t.Run("expired user", func(t *testing.T) {
|
t.Run("expired user", func(t *testing.T) {
|
||||||
expiredUserInfo := validUserInfo
|
expiredUserInfo := validUserInfo
|
||||||
expiredUserInfo.ExpiryTime = mockWorldState.Now().Add(-10 * time.Second).Unix()
|
expiredUserInfo.ExpiryTime = JustInt64(mockWorldState.Now().Add(-10 * time.Second).Unix())
|
||||||
|
|
||||||
_ = mgr.WriteUserInfo(expiredUserInfo)
|
_ = mgr.WriteUserInfo(expiredUserInfo)
|
||||||
err := mgr.AuthoriseNewSession(expiredUserInfo.UID, AuthorisationInfo{NumExistingSessions: 0})
|
err := mgr.AuthoriseNewSession(expiredUserInfo.UID, AuthorisationInfo{NumExistingSessions: 0})
|
||||||
|
|
@ -197,7 +207,7 @@ func TestLocalManager_AuthoriseNewSession(t *testing.T) {
|
||||||
|
|
||||||
t.Run("too many sessions", func(t *testing.T) {
|
t.Run("too many sessions", func(t *testing.T) {
|
||||||
_ = mgr.WriteUserInfo(validUserInfo)
|
_ = mgr.WriteUserInfo(validUserInfo)
|
||||||
err := mgr.AuthoriseNewSession(validUserInfo.UID, AuthorisationInfo{NumExistingSessions: int(validUserInfo.SessionsCap + 1)})
|
err := mgr.AuthoriseNewSession(validUserInfo.UID, AuthorisationInfo{NumExistingSessions: int(*validUserInfo.SessionsCap + 1)})
|
||||||
if err != ErrSessionsCapReached {
|
if err != ErrSessionsCapReached {
|
||||||
t.Error("session cap not reached")
|
t.Error("session cap not reached")
|
||||||
}
|
}
|
||||||
|
|
@ -230,10 +240,10 @@ func TestLocalManager_UploadStatus(t *testing.T) {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if updatedUserInfo.UpCredit != validUserInfo.UpCredit-update.UpUsage {
|
if *updatedUserInfo.UpCredit != *validUserInfo.UpCredit-update.UpUsage {
|
||||||
t.Error("up usage incorrect")
|
t.Error("up usage incorrect")
|
||||||
}
|
}
|
||||||
if updatedUserInfo.DownCredit != validUserInfo.DownCredit-update.DownUsage {
|
if *updatedUserInfo.DownCredit != *validUserInfo.DownCredit-update.DownUsage {
|
||||||
t.Error("down usage incorrect")
|
t.Error("down usage incorrect")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
@ -249,7 +259,7 @@ func TestLocalManager_UploadStatus(t *testing.T) {
|
||||||
UID: validUserInfo.UID,
|
UID: validUserInfo.UID,
|
||||||
Active: true,
|
Active: true,
|
||||||
NumSession: 1,
|
NumSession: 1,
|
||||||
UpUsage: validUserInfo.UpCredit + 100,
|
UpUsage: *validUserInfo.UpCredit + 100,
|
||||||
DownUsage: 0,
|
DownUsage: 0,
|
||||||
Timestamp: mockWorldState.Now().Unix(),
|
Timestamp: mockWorldState.Now().Unix(),
|
||||||
},
|
},
|
||||||
|
|
@ -261,19 +271,19 @@ func TestLocalManager_UploadStatus(t *testing.T) {
|
||||||
Active: true,
|
Active: true,
|
||||||
NumSession: 1,
|
NumSession: 1,
|
||||||
UpUsage: 0,
|
UpUsage: 0,
|
||||||
DownUsage: validUserInfo.DownCredit + 100,
|
DownUsage: *validUserInfo.DownCredit + 100,
|
||||||
Timestamp: mockWorldState.Now().Unix(),
|
Timestamp: mockWorldState.Now().Unix(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{"expired",
|
{"expired",
|
||||||
UserInfo{
|
UserInfo{
|
||||||
UID: mockUID,
|
UID: mockUID,
|
||||||
SessionsCap: 10,
|
SessionsCap: JustInt32(10),
|
||||||
UpRate: 0,
|
UpRate: JustInt64(0),
|
||||||
DownRate: 0,
|
DownRate: JustInt64(0),
|
||||||
UpCredit: 0,
|
UpCredit: JustInt64(0),
|
||||||
DownCredit: 0,
|
DownCredit: JustInt64(0),
|
||||||
ExpiryTime: -1,
|
ExpiryTime: JustInt64(-1),
|
||||||
},
|
},
|
||||||
StatusUpdate{
|
StatusUpdate{
|
||||||
UID: mockUserInfo.UID,
|
UID: mockUserInfo.UID,
|
||||||
|
|
@ -318,12 +328,12 @@ func TestLocalManager_ListAllUsers(t *testing.T) {
|
||||||
rand.Read(randUID)
|
rand.Read(randUID)
|
||||||
newUser := UserInfo{
|
newUser := UserInfo{
|
||||||
UID: randUID,
|
UID: randUID,
|
||||||
SessionsCap: rand.Int31(),
|
SessionsCap: JustInt32(rand.Int31()),
|
||||||
UpRate: rand.Int63(),
|
UpRate: JustInt64(rand.Int63()),
|
||||||
DownRate: rand.Int63(),
|
DownRate: JustInt64(rand.Int63()),
|
||||||
UpCredit: rand.Int63(),
|
UpCredit: JustInt64(rand.Int63()),
|
||||||
DownCredit: rand.Int63(),
|
DownCredit: JustInt64(rand.Int63()),
|
||||||
ExpiryTime: rand.Int63(),
|
ExpiryTime: JustInt64(rand.Int63()),
|
||||||
}
|
}
|
||||||
users = append(users, newUser)
|
users = append(users, newUser)
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
|
|
|
||||||
|
|
@ -14,16 +14,23 @@ type StatusUpdate struct {
|
||||||
Timestamp int64
|
Timestamp int64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type MaybeInt32 *int32
|
||||||
|
type MaybeInt64 *int64
|
||||||
|
|
||||||
type UserInfo struct {
|
type UserInfo struct {
|
||||||
UID []byte
|
UID []byte
|
||||||
SessionsCap int32
|
SessionsCap MaybeInt32
|
||||||
UpRate int64
|
UpRate MaybeInt64
|
||||||
DownRate int64
|
DownRate MaybeInt64
|
||||||
UpCredit int64
|
UpCredit MaybeInt64
|
||||||
DownCredit int64
|
DownCredit MaybeInt64
|
||||||
ExpiryTime int64
|
ExpiryTime MaybeInt64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func JustInt32(v int32) MaybeInt32 { return &v }
|
||||||
|
|
||||||
|
func JustInt64(v int64) MaybeInt64 { return &v }
|
||||||
|
|
||||||
type StatusResponse struct {
|
type StatusResponse struct {
|
||||||
UID []byte
|
UID []byte
|
||||||
Action int
|
Action int
|
||||||
|
|
|
||||||
|
|
@ -66,12 +66,12 @@ var mockUID = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
|
||||||
var mockWorldState = common.WorldOfTime(time.Unix(1, 0))
|
var mockWorldState = common.WorldOfTime(time.Unix(1, 0))
|
||||||
var validUserInfo = usermanager.UserInfo{
|
var validUserInfo = usermanager.UserInfo{
|
||||||
UID: mockUID,
|
UID: mockUID,
|
||||||
SessionsCap: 10,
|
SessionsCap: usermanager.JustInt32(10),
|
||||||
UpRate: 100,
|
UpRate: usermanager.JustInt64(100),
|
||||||
DownRate: 1000,
|
DownRate: usermanager.JustInt64(1000),
|
||||||
UpCredit: 10000,
|
UpCredit: usermanager.JustInt64(10000),
|
||||||
DownCredit: 100000,
|
DownCredit: usermanager.JustInt64(100000),
|
||||||
ExpiryTime: 1000000,
|
ExpiryTime: usermanager.JustInt64(1000000),
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUserPanel_GetUser(t *testing.T) {
|
func TestUserPanel_GetUser(t *testing.T) {
|
||||||
|
|
@ -138,10 +138,10 @@ func TestUserPanel_UpdateUsageQueue(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
updatedUinfo, _ := mgr.GetUserInfo(validUserInfo.UID)
|
updatedUinfo, _ := mgr.GetUserInfo(validUserInfo.UID)
|
||||||
if updatedUinfo.DownCredit != validUserInfo.DownCredit-1 {
|
if *updatedUinfo.DownCredit != *validUserInfo.DownCredit-1 {
|
||||||
t.Error("down credit incorrect update")
|
t.Error("down credit incorrect update")
|
||||||
}
|
}
|
||||||
if updatedUinfo.UpCredit != validUserInfo.UpCredit-2 {
|
if *updatedUinfo.UpCredit != *validUserInfo.UpCredit-2 {
|
||||||
t.Error("up credit incorrect update")
|
t.Error("up credit incorrect update")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -155,10 +155,10 @@ func TestUserPanel_UpdateUsageQueue(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
updatedUinfo, _ = mgr.GetUserInfo(validUserInfo.UID)
|
updatedUinfo, _ = mgr.GetUserInfo(validUserInfo.UID)
|
||||||
if updatedUinfo.DownCredit != validUserInfo.DownCredit-(1+3) {
|
if *updatedUinfo.DownCredit != *validUserInfo.DownCredit-(1+3) {
|
||||||
t.Error("down credit incorrect update")
|
t.Error("down credit incorrect update")
|
||||||
}
|
}
|
||||||
if updatedUinfo.UpCredit != validUserInfo.UpCredit-(2+4) {
|
if *updatedUinfo.UpCredit != *validUserInfo.UpCredit-(2+4) {
|
||||||
t.Error("up credit incorrect update")
|
t.Error("up credit incorrect update")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
@ -170,7 +170,7 @@ func TestUserPanel_UpdateUsageQueue(t *testing.T) {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
user.valve.AddTx(validUserInfo.DownCredit + 100)
|
user.valve.AddTx(*validUserInfo.DownCredit + 100)
|
||||||
panel.updateUsageQueue()
|
panel.updateUsageQueue()
|
||||||
err = panel.commitUpdate()
|
err = panel.commitUpdate()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -182,7 +182,7 @@ func TestUserPanel_UpdateUsageQueue(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
updatedUinfo, _ := mgr.GetUserInfo(validUserInfo.UID)
|
updatedUinfo, _ := mgr.GetUserInfo(validUserInfo.UID)
|
||||||
if updatedUinfo.DownCredit != -100 {
|
if *updatedUinfo.DownCredit != -100 {
|
||||||
t.Error("down credit not updated correctly after the user has been terminated")
|
t.Error("down credit not updated correctly after the user has been terminated")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -321,7 +321,7 @@ func TestTCPSingleplex(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
const echoMsgLen = 16384
|
const echoMsgLen = 1 << 16
|
||||||
go serveTCPEcho(proxyFromCkServerL)
|
go serveTCPEcho(proxyFromCkServerL)
|
||||||
|
|
||||||
proxyConn1, err := proxyToCkClientD.Dial("", "")
|
proxyConn1, err := proxyToCkClientD.Dial("", "")
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ if [ -z "$v" ]; then
|
||||||
fi
|
fi
|
||||||
|
|
||||||
output="{{.Dir}}-{{.OS}}-{{.Arch}}-$v"
|
output="{{.Dir}}-{{.OS}}-{{.Arch}}-$v"
|
||||||
osarch="!darwin/arm !darwin/arm64 !darwin/386"
|
osarch="!darwin/arm !darwin/386"
|
||||||
|
|
||||||
echo "Compiling:"
|
echo "Compiling:"
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue