mirror of https://github.com/cbeuw/Cloak
Merge branch 'master' into notsure2
This commit is contained in:
commit
9d79842536
|
|
@ -13,10 +13,10 @@ jobs:
|
||||||
- name: Build
|
- name: Build
|
||||||
run: |
|
run: |
|
||||||
export PATH=${PATH}:`go env GOPATH`/bin
|
export PATH=${PATH}:`go env GOPATH`/bin
|
||||||
v=${{ github.ref }} ./release.sh
|
v=${GITHUB_REF#refs/*/} ./release.sh
|
||||||
- name: Release
|
- name: Release
|
||||||
uses: softprops/action-gh-release@v1
|
uses: softprops/action-gh-release@v1
|
||||||
with:
|
with:
|
||||||
files: ./release/ck-*
|
files: release/*
|
||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
14
README.md
14
README.md
|
|
@ -103,15 +103,13 @@ Example:
|
||||||
|
|
||||||
`PrivateKey` is the static curve25519 Diffie-Hellman private key encoded in base64.
|
`PrivateKey` is the static curve25519 Diffie-Hellman private key encoded in base64.
|
||||||
|
|
||||||
`AdminUID` is the UID of the admin user in base64.
|
|
||||||
|
|
||||||
`BypassUID` is a list of UIDs that are authorised without any bandwidth or credit limit restrictions
|
`BypassUID` is a list of UIDs that are authorised without any bandwidth or credit limit restrictions
|
||||||
|
|
||||||
`DatabasePath` is the path to `userinfo.db`. If `userinfo.db` doesn't exist in this directory, Cloak will create one
|
`AdminUID` is the UID of the admin user in base64. You can leave this empty if you only ever add users to `BypassUID`.
|
||||||
automatically. **If Cloak is started as a Shadowsocks plugin and Shadowsocks is started with its working directory as
|
|
||||||
/ (e.g. starting ss-server with systemctl), you need to set this field as an absolute path to a desired folder. If you
|
`DatabasePath` is the path to `userinfo.db`, which is used to store user usage information and restrictions. Cloak will
|
||||||
leave it as default then Cloak will attempt to create userinfo.db under /, which it doesn't have the permission to do so
|
create the file automatically if it doesn't exist. You can leave this empty if you only ever add users to `BypassUID`.
|
||||||
and will raise an error. See Issue #13.**
|
This field also has no effect if `AdminUID` isn't a valid UID or is empty.
|
||||||
|
|
||||||
`KeepAlive` is the number of seconds to tell the OS to wait after no activity before sending TCP KeepAlive probes to the
|
`KeepAlive` is the number of seconds to tell the OS to wait after no activity before sending TCP KeepAlive probes to the
|
||||||
upstream proxy server. Zero or negative value disables it. Default is 0 (disabled).
|
upstream proxy server. Zero or negative value disables it. Default is 0 (disabled).
|
||||||
|
|
@ -184,6 +182,8 @@ Run `ck-server -uid` and add the UID into the `BypassUID` field in `ckserver.jso
|
||||||
|
|
||||||
##### Users subject to bandwidth and credit controls
|
##### Users subject to bandwidth and credit controls
|
||||||
|
|
||||||
|
0. First make sure you have `AdminUID` generated and set in `ckserver.json`, along with a path to `userinfo.db`
|
||||||
|
in `DatabasePath` (Cloak will create this file for you if it didn't already exist).
|
||||||
1. On your client, run `ck-client -s <IP of the server> -l <A local port> -a <AdminUID> -c <path-to-ckclient.json>` to
|
1. On your client, run `ck-client -s <IP of the server> -l <A local port> -a <AdminUID> -c <path-to-ckclient.json>` to
|
||||||
enter admin mode
|
enter admin mode
|
||||||
2. Visit https://cbeuw.github.io/Cloak-panel (Note: this is a pure-js static site, there is no backend and all data
|
2. Visit https://cbeuw.github.io/Cloak-panel (Note: this is a pure-js static site, there is no backend and all data
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"github.com/cbeuw/Cloak/internal/common"
|
"github.com/cbeuw/Cloak/internal/common"
|
||||||
"github.com/cbeuw/Cloak/internal/ecdh"
|
"github.com/cbeuw/Cloak/internal/ecdh"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|
@ -26,7 +27,10 @@ func makeAuthenticationPayload(authInfo AuthInfo) (ret authenticationPayload, sh
|
||||||
| 16 bytes | 12 bytes | 1 byte | 8 bytes | 4 bytes | 1 byte | 6 bytes |
|
| 16 bytes | 12 bytes | 1 byte | 8 bytes | 4 bytes | 1 byte | 6 bytes |
|
||||||
+----------+----------------+---------------------+-------------+--------------+--------+------------+
|
+----------+----------------+---------------------+-------------+--------------+--------+------------+
|
||||||
*/
|
*/
|
||||||
ephPv, ephPub, _ := ecdh.GenerateKey(authInfo.WorldState.Rand)
|
ephPv, ephPub, err := ecdh.GenerateKey(authInfo.WorldState.Rand)
|
||||||
|
if err != nil {
|
||||||
|
log.Panicf("failed to generate ephemeral key pair: %v", err)
|
||||||
|
}
|
||||||
copy(ret.randPubKey[:], ecdh.Marshal(ephPub))
|
copy(ret.randPubKey[:], ecdh.Marshal(ephPub))
|
||||||
|
|
||||||
plaintext := make([]byte, 48)
|
plaintext := make([]byte, 48)
|
||||||
|
|
@ -40,7 +44,11 @@ func makeAuthenticationPayload(authInfo AuthInfo) (ret authenticationPayload, sh
|
||||||
plaintext[41] |= UNORDERED_FLAG
|
plaintext[41] |= UNORDERED_FLAG
|
||||||
}
|
}
|
||||||
|
|
||||||
copy(sharedSecret[:], ecdh.GenerateSharedSecret(ephPv, authInfo.ServerPubKey))
|
secret, err := ecdh.GenerateSharedSecret(ephPv, authInfo.ServerPubKey)
|
||||||
|
if err != nil {
|
||||||
|
log.Panicf("error in generating shared secret: %v", err)
|
||||||
|
}
|
||||||
|
copy(sharedSecret[:], secret)
|
||||||
ciphertextWithTag, _ := common.AESGCMEncrypt(ret.randPubKey[:12], sharedSecret[:], plaintext)
|
ciphertextWithTag, _ := common.AESGCMEncrypt(ret.randPubKey[:12], sharedSecret[:], plaintext)
|
||||||
copy(ret.ciphertextWithTag[:], ciphertextWithTag[:])
|
copy(ret.ciphertextWithTag[:], ciphertextWithTag[:])
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -68,13 +68,11 @@ func Unmarshal(data []byte) (crypto.PublicKey, bool) {
|
||||||
return &pub, true
|
return &pub, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenerateSharedSecret(privKey crypto.PrivateKey, pubKey crypto.PublicKey) []byte {
|
func GenerateSharedSecret(privKey crypto.PrivateKey, pubKey crypto.PublicKey) ([]byte, error) {
|
||||||
var priv, pub, secret *[32]byte
|
var priv, pub *[32]byte
|
||||||
|
|
||||||
priv = privKey.(*[32]byte)
|
priv = privKey.(*[32]byte)
|
||||||
pub = pubKey.(*[32]byte)
|
pub = pubKey.(*[32]byte)
|
||||||
secret = new([32]byte)
|
|
||||||
|
|
||||||
curve25519.ScalarMult(secret, priv, pub)
|
return curve25519.X25519(priv[:], pub[:])
|
||||||
return secret[:]
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -90,11 +90,11 @@ func testECDH(t testing.TB) {
|
||||||
t.Fatalf("Unmarshal does not work")
|
t.Fatalf("Unmarshal does not work")
|
||||||
}
|
}
|
||||||
|
|
||||||
secret1 = GenerateSharedSecret(privKey1, pubKey2)
|
secret1, err = GenerateSharedSecret(privKey1, pubKey2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
secret2 = GenerateSharedSecret(privKey2, pubKey1)
|
secret2, err = GenerateSharedSecret(privKey2, pubKey1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -112,7 +112,7 @@ func (d *datagramBufferedPipe) WriteTo(w io.Writer) (n int64, err error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *datagramBufferedPipe) Write(f Frame) (toBeClosed bool, err error) {
|
func (d *datagramBufferedPipe) Write(f *Frame) (toBeClosed bool, err error) {
|
||||||
d.rwCond.L.Lock()
|
d.rwCond.L.Lock()
|
||||||
defer d.rwCond.L.Unlock()
|
defer d.rwCond.L.Unlock()
|
||||||
if d.buf == nil {
|
if d.buf == nil {
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ func TestDatagramBuffer_RW(t *testing.T) {
|
||||||
b := []byte{0x01, 0x02, 0x03}
|
b := []byte{0x01, 0x02, 0x03}
|
||||||
t.Run("simple write", func(t *testing.T) {
|
t.Run("simple write", func(t *testing.T) {
|
||||||
pipe := NewDatagramBufferedPipe()
|
pipe := NewDatagramBufferedPipe()
|
||||||
_, err := pipe.Write(Frame{Payload: b})
|
_, err := pipe.Write(&Frame{Payload: b})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(
|
t.Error(
|
||||||
"expecting", "nil error",
|
"expecting", "nil error",
|
||||||
|
|
@ -22,7 +22,7 @@ func TestDatagramBuffer_RW(t *testing.T) {
|
||||||
|
|
||||||
t.Run("simple read", func(t *testing.T) {
|
t.Run("simple read", func(t *testing.T) {
|
||||||
pipe := NewDatagramBufferedPipe()
|
pipe := NewDatagramBufferedPipe()
|
||||||
_, _ = pipe.Write(Frame{Payload: b})
|
_, _ = pipe.Write(&Frame{Payload: b})
|
||||||
b2 := make([]byte, len(b))
|
b2 := make([]byte, len(b))
|
||||||
n, err := pipe.Read(b2)
|
n, err := pipe.Read(b2)
|
||||||
if n != len(b) {
|
if n != len(b) {
|
||||||
|
|
@ -55,7 +55,7 @@ func TestDatagramBuffer_RW(t *testing.T) {
|
||||||
|
|
||||||
t.Run("writing closing frame", func(t *testing.T) {
|
t.Run("writing closing frame", func(t *testing.T) {
|
||||||
pipe := NewDatagramBufferedPipe()
|
pipe := NewDatagramBufferedPipe()
|
||||||
toBeClosed, err := pipe.Write(Frame{Closing: closingStream})
|
toBeClosed, err := pipe.Write(&Frame{Closing: closingStream})
|
||||||
if !toBeClosed {
|
if !toBeClosed {
|
||||||
t.Error("should be to be closed")
|
t.Error("should be to be closed")
|
||||||
}
|
}
|
||||||
|
|
@ -77,7 +77,7 @@ func TestDatagramBuffer_BlockingRead(t *testing.T) {
|
||||||
b := []byte{0x01, 0x02, 0x03}
|
b := []byte{0x01, 0x02, 0x03}
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(readBlockTime)
|
time.Sleep(readBlockTime)
|
||||||
pipe.Write(Frame{Payload: b})
|
pipe.Write(&Frame{Payload: b})
|
||||||
}()
|
}()
|
||||||
b2 := make([]byte, len(b))
|
b2 := make([]byte, len(b))
|
||||||
n, err := pipe.Read(b2)
|
n, err := pipe.Read(b2)
|
||||||
|
|
@ -110,7 +110,7 @@ func TestDatagramBuffer_BlockingRead(t *testing.T) {
|
||||||
func TestDatagramBuffer_CloseThenRead(t *testing.T) {
|
func TestDatagramBuffer_CloseThenRead(t *testing.T) {
|
||||||
pipe := NewDatagramBufferedPipe()
|
pipe := NewDatagramBufferedPipe()
|
||||||
b := []byte{0x01, 0x02, 0x03}
|
b := []byte{0x01, 0x02, 0x03}
|
||||||
pipe.Write(Frame{Payload: b})
|
pipe.Write(&Frame{Payload: b})
|
||||||
b2 := make([]byte, len(b))
|
b2 := make([]byte, len(b))
|
||||||
pipe.Close()
|
pipe.Close()
|
||||||
n, err := pipe.Read(b2)
|
n, err := pipe.Read(b2)
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,6 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func serveEcho(l net.Listener) {
|
func serveEcho(l net.Listener) {
|
||||||
|
|
@ -64,21 +63,20 @@ func makeSessionPair(numConn int) (*Session, *Session, []*connPair) {
|
||||||
return clientSession, serverSession, paris
|
return clientSession, serverSession, paris
|
||||||
}
|
}
|
||||||
|
|
||||||
func runEchoTest(t *testing.T, conns []net.Conn, maxMsgLen int) {
|
func runEchoTest(t *testing.T, conns []net.Conn, msgLen int) {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
for _, conn := range conns {
|
for _, conn := range conns {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(conn net.Conn) {
|
go func(conn net.Conn) {
|
||||||
testDataLen := rand.Intn(maxMsgLen)
|
testData := make([]byte, msgLen)
|
||||||
testData := make([]byte, testDataLen)
|
|
||||||
rand.Read(testData)
|
rand.Read(testData)
|
||||||
|
|
||||||
n, err := conn.Write(testData)
|
n, err := conn.Write(testData)
|
||||||
if n != testDataLen {
|
if n != msgLen {
|
||||||
t.Fatalf("written only %v, err %v", n, err)
|
t.Fatalf("written only %v, err %v", n, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
recvBuf := make([]byte, testDataLen)
|
recvBuf := make([]byte, msgLen)
|
||||||
_, err = io.ReadFull(conn, recvBuf)
|
_, err = io.ReadFull(conn, recvBuf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to read back: %v", err)
|
t.Fatalf("failed to read back: %v", err)
|
||||||
|
|
@ -96,7 +94,7 @@ func runEchoTest(t *testing.T, conns []net.Conn, maxMsgLen int) {
|
||||||
func TestMultiplex(t *testing.T) {
|
func TestMultiplex(t *testing.T) {
|
||||||
const numStreams = 2000 // -race option limits the number of goroutines to 8192
|
const numStreams = 2000 // -race option limits the number of goroutines to 8192
|
||||||
const numConns = 4
|
const numConns = 4
|
||||||
const maxMsgLen = 16384
|
const msgLen = 16384
|
||||||
|
|
||||||
clientSession, serverSession, _ := makeSessionPair(numConns)
|
clientSession, serverSession, _ := makeSessionPair(numConns)
|
||||||
go serveEcho(serverSession)
|
go serveEcho(serverSession)
|
||||||
|
|
@ -111,15 +109,10 @@ func TestMultiplex(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
//test echo
|
//test echo
|
||||||
runEchoTest(t, streams, maxMsgLen)
|
runEchoTest(t, streams, msgLen)
|
||||||
|
|
||||||
assert.Eventuallyf(t, func() bool {
|
assert.EqualValues(t, numStreams, clientSession.streamCount(), "client stream count is wrong")
|
||||||
return clientSession.streamCount() == numStreams
|
assert.EqualValues(t, numStreams, serverSession.streamCount(), "server stream count is wrong")
|
||||||
}, time.Second, 10*time.Millisecond, "client stream count is wrong: %v", clientSession.streamCount())
|
|
||||||
|
|
||||||
assert.Eventuallyf(t, func() bool {
|
|
||||||
return serverSession.streamCount() == numStreams
|
|
||||||
}, time.Second, 10*time.Millisecond, "server stream count is wrong: %v", serverSession.streamCount())
|
|
||||||
|
|
||||||
// close one stream
|
// close one stream
|
||||||
closing, streams := streams[0], streams[1:]
|
closing, streams := streams[0], streams[1:]
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Obfser func(*Frame, []byte, int) (int, error)
|
type Obfser func(*Frame, []byte, int) (int, error)
|
||||||
type Deobfser func([]byte) (*Frame, error)
|
type Deobfser func(*Frame, []byte) error
|
||||||
|
|
||||||
var u32 = binary.BigEndian.Uint32
|
var u32 = binary.BigEndian.Uint32
|
||||||
var u64 = binary.BigEndian.Uint64
|
var u64 = binary.BigEndian.Uint64
|
||||||
|
|
@ -135,9 +135,9 @@ func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Obfser {
|
||||||
func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser {
|
func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser {
|
||||||
// frame header length + minimum data size (i.e. nonce size of salsa20)
|
// frame header length + minimum data size (i.e. nonce size of salsa20)
|
||||||
const minInputLen = frameHeaderLength + salsa20NonceSize
|
const minInputLen = frameHeaderLength + salsa20NonceSize
|
||||||
deobfs := func(in []byte) (*Frame, error) {
|
deobfs := func(f *Frame, in []byte) error {
|
||||||
if len(in) < minInputLen {
|
if len(in) < minInputLen {
|
||||||
return nil, fmt.Errorf("input size %v, but it cannot be shorter than %v bytes", len(in), minInputLen)
|
return fmt.Errorf("input size %v, but it cannot be shorter than %v bytes", len(in), minInputLen)
|
||||||
}
|
}
|
||||||
|
|
||||||
header := in[:frameHeaderLength]
|
header := in[:frameHeaderLength]
|
||||||
|
|
@ -153,7 +153,7 @@ func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser {
|
||||||
|
|
||||||
usefulPayloadLen := len(pldWithOverHead) - int(extraLen)
|
usefulPayloadLen := len(pldWithOverHead) - int(extraLen)
|
||||||
if usefulPayloadLen < 0 || usefulPayloadLen > len(pldWithOverHead) {
|
if usefulPayloadLen < 0 || usefulPayloadLen > len(pldWithOverHead) {
|
||||||
return nil, errors.New("extra length is negative or extra length is greater than total pldWithOverHead length")
|
return errors.New("extra length is negative or extra length is greater than total pldWithOverHead length")
|
||||||
}
|
}
|
||||||
|
|
||||||
var outputPayload []byte
|
var outputPayload []byte
|
||||||
|
|
@ -167,18 +167,16 @@ func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser {
|
||||||
} else {
|
} else {
|
||||||
_, err := payloadCipher.Open(pldWithOverHead[:0], header[:payloadCipher.NonceSize()], pldWithOverHead, nil)
|
_, err := payloadCipher.Open(pldWithOverHead[:0], header[:payloadCipher.NonceSize()], pldWithOverHead, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
outputPayload = pldWithOverHead[:usefulPayloadLen]
|
outputPayload = pldWithOverHead[:usefulPayloadLen]
|
||||||
}
|
}
|
||||||
|
|
||||||
ret := &Frame{
|
f.StreamID = streamID
|
||||||
StreamID: streamID,
|
f.Seq = seq
|
||||||
Seq: seq,
|
f.Closing = closing
|
||||||
Closing: closing,
|
f.Payload = outputPayload
|
||||||
Payload: outputPayload,
|
return nil
|
||||||
}
|
|
||||||
return ret, nil
|
|
||||||
}
|
}
|
||||||
return deobfs
|
return deobfs
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -17,8 +17,7 @@ func TestGenerateObfs(t *testing.T) {
|
||||||
|
|
||||||
run := func(obfuscator Obfuscator, ct *testing.T) {
|
run := func(obfuscator Obfuscator, ct *testing.T) {
|
||||||
obfsBuf := make([]byte, 512)
|
obfsBuf := make([]byte, 512)
|
||||||
f := &Frame{}
|
_testFrame, _ := quick.Value(reflect.TypeOf(&Frame{}), rand.New(rand.NewSource(42)))
|
||||||
_testFrame, _ := quick.Value(reflect.TypeOf(f), rand.New(rand.NewSource(42)))
|
|
||||||
testFrame := _testFrame.Interface().(*Frame)
|
testFrame := _testFrame.Interface().(*Frame)
|
||||||
i, err := obfuscator.Obfs(testFrame, obfsBuf, 0)
|
i, err := obfuscator.Obfs(testFrame, obfsBuf, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -26,7 +25,8 @@ func TestGenerateObfs(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resultFrame, err := obfuscator.Deobfs(obfsBuf[:i])
|
var resultFrame Frame
|
||||||
|
err = obfuscator.Deobfs(&resultFrame, obfsBuf[:i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ct.Error("failed to deobfs ", err)
|
ct.Error("failed to deobfs ", err)
|
||||||
return
|
return
|
||||||
|
|
@ -148,10 +148,11 @@ func BenchmarkDeobfs(b *testing.B) {
|
||||||
n, _ := obfs(testFrame, obfsBuf, 0)
|
n, _ := obfs(testFrame, obfsBuf, 0)
|
||||||
deobfs := MakeDeobfs(key, payloadCipher)
|
deobfs := MakeDeobfs(key, payloadCipher)
|
||||||
|
|
||||||
|
frame := new(Frame)
|
||||||
b.SetBytes(int64(n))
|
b.SetBytes(int64(n))
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
deobfs(obfsBuf[:n])
|
deobfs(frame, obfsBuf[:n])
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
b.Run("AES128GCM", func(b *testing.B) {
|
b.Run("AES128GCM", func(b *testing.B) {
|
||||||
|
|
@ -162,10 +163,11 @@ func BenchmarkDeobfs(b *testing.B) {
|
||||||
n, _ := obfs(testFrame, obfsBuf, 0)
|
n, _ := obfs(testFrame, obfsBuf, 0)
|
||||||
deobfs := MakeDeobfs(key, payloadCipher)
|
deobfs := MakeDeobfs(key, payloadCipher)
|
||||||
|
|
||||||
|
frame := new(Frame)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
b.SetBytes(int64(n))
|
b.SetBytes(int64(n))
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
deobfs(obfsBuf[:n])
|
deobfs(frame, obfsBuf[:n])
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
b.Run("plain", func(b *testing.B) {
|
b.Run("plain", func(b *testing.B) {
|
||||||
|
|
@ -173,10 +175,11 @@ func BenchmarkDeobfs(b *testing.B) {
|
||||||
n, _ := obfs(testFrame, obfsBuf, 0)
|
n, _ := obfs(testFrame, obfsBuf, 0)
|
||||||
deobfs := MakeDeobfs(key, nil)
|
deobfs := MakeDeobfs(key, nil)
|
||||||
|
|
||||||
|
frame := new(Frame)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
b.SetBytes(int64(n))
|
b.SetBytes(int64(n))
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
deobfs(obfsBuf[:n])
|
deobfs(frame, obfsBuf[:n])
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
b.Run("chacha20Poly1305", func(b *testing.B) {
|
b.Run("chacha20Poly1305", func(b *testing.B) {
|
||||||
|
|
@ -186,10 +189,11 @@ func BenchmarkDeobfs(b *testing.B) {
|
||||||
n, _ := obfs(testFrame, obfsBuf, 0)
|
n, _ := obfs(testFrame, obfsBuf, 0)
|
||||||
deobfs := MakeDeobfs(key, payloadCipher)
|
deobfs := MakeDeobfs(key, payloadCipher)
|
||||||
|
|
||||||
|
frame := new(Frame)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
b.SetBytes(int64(n))
|
b.SetBytes(int64(n))
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
deobfs(obfsBuf[:n])
|
deobfs(frame, obfsBuf[:n])
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ type recvBuffer interface {
|
||||||
// when the buffer is empty.
|
// when the buffer is empty.
|
||||||
io.ReadCloser
|
io.ReadCloser
|
||||||
io.WriterTo
|
io.WriterTo
|
||||||
Write(Frame) (toBeClosed bool, err error)
|
Write(*Frame) (toBeClosed bool, err error)
|
||||||
SetReadDeadline(time time.Time)
|
SetReadDeadline(time time.Time)
|
||||||
// SetWriteToTimeout sets the duration a recvBuffer waits in a WriteTo call when nothing
|
// SetWriteToTimeout sets the duration a recvBuffer waits in a WriteTo call when nothing
|
||||||
// has been written for a while. After that duration it should return ErrTimeout
|
// has been written for a while. After that duration it should return ErrTimeout
|
||||||
|
|
|
||||||
|
|
@ -63,7 +63,12 @@ type Session struct {
|
||||||
|
|
||||||
// atomic
|
// atomic
|
||||||
activeStreamCount uint32
|
activeStreamCount uint32
|
||||||
streams sync.Map
|
|
||||||
|
streamsM sync.Mutex
|
||||||
|
streams map[uint32]*Stream
|
||||||
|
|
||||||
|
// a pool of heap allocated frame objects so we don't have to allocate a new one each time we receive a frame
|
||||||
|
recvFramePool sync.Pool
|
||||||
|
|
||||||
// Switchboard manages all connections to remote
|
// Switchboard manages all connections to remote
|
||||||
sb *switchboard
|
sb *switchboard
|
||||||
|
|
@ -89,6 +94,8 @@ func MakeSession(id uint32, config SessionConfig) *Session {
|
||||||
SessionConfig: config,
|
SessionConfig: config,
|
||||||
nextStreamID: 1,
|
nextStreamID: 1,
|
||||||
acceptCh: make(chan *Stream, acceptBacklog),
|
acceptCh: make(chan *Stream, acceptBacklog),
|
||||||
|
recvFramePool: sync.Pool{New: func() interface{} { return &Frame{} }},
|
||||||
|
streams: map[uint32]*Stream{},
|
||||||
}
|
}
|
||||||
sesh.addrs.Store([]net.Addr{nil, nil})
|
sesh.addrs.Store([]net.Addr{nil, nil})
|
||||||
|
|
||||||
|
|
@ -145,7 +152,9 @@ func (sesh *Session) OpenStream() (*Stream, error) {
|
||||||
return nil, errNoMultiplex
|
return nil, errNoMultiplex
|
||||||
}
|
}
|
||||||
stream := makeStream(sesh, id)
|
stream := makeStream(sesh, id)
|
||||||
sesh.streams.Store(id, stream)
|
sesh.streamsM.Lock()
|
||||||
|
sesh.streams[id] = stream
|
||||||
|
sesh.streamsM.Unlock()
|
||||||
sesh.streamCountIncr()
|
sesh.streamCountIncr()
|
||||||
log.Tracef("stream %v of session %v opened", id, sesh.id)
|
log.Tracef("stream %v of session %v opened", id, sesh.id)
|
||||||
return stream, nil
|
return stream, nil
|
||||||
|
|
@ -165,24 +174,22 @@ func (sesh *Session) Accept() (net.Conn, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sesh *Session) closeStream(s *Stream, active bool) error {
|
func (sesh *Session) closeStream(s *Stream, active bool) error {
|
||||||
|
// must be holding s.wirtingM on entry
|
||||||
if atomic.SwapUint32(&s.closed, 1) == 1 {
|
if atomic.SwapUint32(&s.closed, 1) == 1 {
|
||||||
return fmt.Errorf("closing stream %v: %w", s.id, errRepeatStreamClosing)
|
return fmt.Errorf("closing stream %v: %w", s.id, errRepeatStreamClosing)
|
||||||
}
|
}
|
||||||
_ = s.recvBuf.Close() // recvBuf.Close should not return error
|
_ = s.getRecvBuf().Close() // recvBuf.Close should not return error
|
||||||
|
|
||||||
if active {
|
if active {
|
||||||
// Notify remote that this stream is closed
|
// Notify remote that this stream is closed
|
||||||
padding := genRandomPadding()
|
padding := genRandomPadding()
|
||||||
f := &Frame{
|
s.writingFrame.Closing = closingStream
|
||||||
StreamID: s.id,
|
s.writingFrame.Payload = padding
|
||||||
Seq: s.nextSendSeq,
|
|
||||||
Closing: closingStream,
|
|
||||||
Payload: padding,
|
|
||||||
}
|
|
||||||
s.nextSendSeq++
|
|
||||||
|
|
||||||
obfsBuf := make([]byte, len(padding)+frameHeaderLength+sesh.Obfuscator.maxOverhead)
|
obfsBuf := make([]byte, len(padding)+frameHeaderLength+sesh.Obfuscator.maxOverhead)
|
||||||
i, err := sesh.Obfs(f, obfsBuf, 0)
|
|
||||||
|
i, err := sesh.Obfs(&s.writingFrame, obfsBuf, 0)
|
||||||
|
s.writingFrame.Seq++
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -190,7 +197,7 @@ func (sesh *Session) closeStream(s *Stream, active bool) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
log.Tracef("stream %v actively closed. seq %v", s.id, f.Seq)
|
log.Tracef("stream %v actively closed.", s.id)
|
||||||
} else {
|
} else {
|
||||||
log.Tracef("stream %v passively closed", s.id)
|
log.Tracef("stream %v passively closed", s.id)
|
||||||
}
|
}
|
||||||
|
|
@ -198,7 +205,9 @@ func (sesh *Session) closeStream(s *Stream, active bool) error {
|
||||||
// We set it as nil to signify that the stream id had existed before.
|
// We set it as nil to signify that the stream id had existed before.
|
||||||
// If we Delete(s.id) straight away, later on in recvDataFromRemote, it will not be able to tell
|
// If we Delete(s.id) straight away, later on in recvDataFromRemote, it will not be able to tell
|
||||||
// if the frame it received was from a new stream or a dying stream whose frame arrived late
|
// if the frame it received was from a new stream or a dying stream whose frame arrived late
|
||||||
sesh.streams.Store(s.id, nil)
|
sesh.streamsM.Lock()
|
||||||
|
sesh.streams[s.id] = nil
|
||||||
|
sesh.streamsM.Unlock()
|
||||||
if sesh.streamCountDecr() == 0 {
|
if sesh.streamCountDecr() == 0 {
|
||||||
if sesh.Singleplex {
|
if sesh.Singleplex {
|
||||||
return sesh.Close()
|
return sesh.Close()
|
||||||
|
|
@ -214,7 +223,10 @@ func (sesh *Session) closeStream(s *Stream, active bool) error {
|
||||||
// to the stream buffer, otherwise it fetches the desired stream instance, or creates and stores one if it's a new
|
// to the stream buffer, otherwise it fetches the desired stream instance, or creates and stores one if it's a new
|
||||||
// stream and then writes to the stream buffer
|
// stream and then writes to the stream buffer
|
||||||
func (sesh *Session) recvDataFromRemote(data []byte) error {
|
func (sesh *Session) recvDataFromRemote(data []byte) error {
|
||||||
frame, err := sesh.Deobfs(data)
|
frame := sesh.recvFramePool.Get().(*Frame)
|
||||||
|
defer sesh.recvFramePool.Put(frame)
|
||||||
|
|
||||||
|
err := sesh.Deobfs(frame, data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Failed to decrypt a frame for session %v: %v", sesh.id, err)
|
return fmt.Errorf("Failed to decrypt a frame for session %v: %v", sesh.id, err)
|
||||||
}
|
}
|
||||||
|
|
@ -224,19 +236,23 @@ func (sesh *Session) recvDataFromRemote(data []byte) error {
|
||||||
return sesh.passiveClose()
|
return sesh.passiveClose()
|
||||||
}
|
}
|
||||||
|
|
||||||
newStream := makeStream(sesh, frame.StreamID)
|
sesh.streamsM.Lock()
|
||||||
existingStreamI, existing := sesh.streams.LoadOrStore(frame.StreamID, newStream)
|
existingStream, existing := sesh.streams[frame.StreamID]
|
||||||
if existing {
|
if existing {
|
||||||
if existingStreamI == nil {
|
sesh.streamsM.Unlock()
|
||||||
|
if existingStream == nil {
|
||||||
// this is when the stream existed before but has since been closed. We do nothing
|
// this is when the stream existed before but has since been closed. We do nothing
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return existingStreamI.(*Stream).recvFrame(*frame)
|
return existingStream.recvFrame(frame)
|
||||||
} else {
|
} else {
|
||||||
|
newStream := makeStream(sesh, frame.StreamID)
|
||||||
|
sesh.streams[frame.StreamID] = newStream
|
||||||
|
sesh.streamsM.Unlock()
|
||||||
// new stream
|
// new stream
|
||||||
sesh.streamCountIncr()
|
sesh.streamCountIncr()
|
||||||
sesh.acceptCh <- newStream
|
sesh.acceptCh <- newStream
|
||||||
return newStream.recvFrame(*frame)
|
return newStream.recvFrame(frame)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -260,17 +276,17 @@ func (sesh *Session) closeSession(closeSwitchboard bool) error {
|
||||||
}
|
}
|
||||||
sesh.acceptCh <- nil
|
sesh.acceptCh <- nil
|
||||||
|
|
||||||
sesh.streams.Range(func(key, streamI interface{}) bool {
|
sesh.streamsM.Lock()
|
||||||
if streamI == nil {
|
for id, stream := range sesh.streams {
|
||||||
return true
|
if stream == nil {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
stream := streamI.(*Stream)
|
|
||||||
atomic.StoreUint32(&stream.closed, 1)
|
atomic.StoreUint32(&stream.closed, 1)
|
||||||
_ = stream.recvBuf.Close() // will not block
|
_ = stream.getRecvBuf().Close() // will not block
|
||||||
sesh.streams.Delete(key)
|
delete(sesh.streams, id)
|
||||||
sesh.streamCountDecr()
|
sesh.streamCountDecr()
|
||||||
return true
|
}
|
||||||
})
|
sesh.streamsM.Unlock()
|
||||||
|
|
||||||
if closeSwitchboard {
|
if closeSwitchboard {
|
||||||
sesh.sb.closeAll()
|
sesh.sb.closeAll()
|
||||||
|
|
|
||||||
|
|
@ -12,10 +12,9 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var seshConfigOrdered = SessionConfig{}
|
var seshConfigs = map[string]SessionConfig{
|
||||||
|
"ordered": {},
|
||||||
var seshConfigUnordered = SessionConfig{
|
"unordered": {Unordered: true},
|
||||||
Unordered: true,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const testPayloadLen = 1024
|
const testPayloadLen = 1024
|
||||||
|
|
@ -43,40 +42,20 @@ func TestRecvDataFromRemote(t *testing.T) {
|
||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionTypes := []struct {
|
encryptionMethods := map[string]Obfuscator{
|
||||||
name string
|
"plain": MakeObfuscatorUnwrap(EncryptionMethodPlain, sessionKey),
|
||||||
config SessionConfig
|
"aes-gcm": MakeObfuscatorUnwrap(EncryptionMethodAESGCM, sessionKey),
|
||||||
}{
|
"chacha20-poly1305": MakeObfuscatorUnwrap(EncryptionMethodChaha20Poly1305, sessionKey),
|
||||||
{"ordered",
|
|
||||||
SessionConfig{}},
|
|
||||||
{"unordered",
|
|
||||||
SessionConfig{Unordered: true}},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
encryptionMethods := []struct {
|
for seshType, seshConfig := range seshConfigs {
|
||||||
name string
|
seshConfig := seshConfig
|
||||||
obfuscator Obfuscator
|
t.Run(seshType, func(t *testing.T) {
|
||||||
}{
|
for method, obfuscator := range encryptionMethods {
|
||||||
{
|
obfuscator := obfuscator
|
||||||
"plain",
|
t.Run(method, func(t *testing.T) {
|
||||||
MakeObfuscatorUnwrap(EncryptionMethodPlain, sessionKey),
|
seshConfig.Obfuscator = obfuscator
|
||||||
},
|
sesh := MakeSession(0, seshConfig)
|
||||||
{
|
|
||||||
"aes-gcm",
|
|
||||||
MakeObfuscatorUnwrap(EncryptionMethodAESGCM, sessionKey),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"chacha20-poly1305",
|
|
||||||
MakeObfuscatorUnwrap(EncryptionMethodChaha20Poly1305, sessionKey),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, st := range sessionTypes {
|
|
||||||
t.Run(st.name, func(t *testing.T) {
|
|
||||||
for _, em := range encryptionMethods {
|
|
||||||
t.Run(em.name, func(t *testing.T) {
|
|
||||||
st.config.Obfuscator = em.obfuscator
|
|
||||||
sesh := MakeSession(0, st.config)
|
|
||||||
n, err := sesh.Obfs(f, obfsBuf, 0)
|
n, err := sesh.Obfs(f, obfsBuf, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
|
|
@ -116,8 +95,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
|
||||||
var sessionKey [32]byte
|
var sessionKey [32]byte
|
||||||
rand.Read(sessionKey[:])
|
rand.Read(sessionKey[:])
|
||||||
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey)
|
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey)
|
||||||
seshConfigOrdered.Obfuscator = obfuscator
|
|
||||||
sesh := MakeSession(0, seshConfigOrdered)
|
seshConfig := seshConfigs["ordered"]
|
||||||
|
seshConfig.Obfuscator = obfuscator
|
||||||
|
sesh := MakeSession(0, seshConfig)
|
||||||
|
|
||||||
f1 := &Frame{
|
f1 := &Frame{
|
||||||
1,
|
1,
|
||||||
|
|
@ -131,7 +112,9 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("receiving normal frame for stream 1: %v", err)
|
t.Fatalf("receiving normal frame for stream 1: %v", err)
|
||||||
}
|
}
|
||||||
_, ok := sesh.streams.Load(f1.StreamID)
|
sesh.streamsM.Lock()
|
||||||
|
_, ok := sesh.streams[f1.StreamID]
|
||||||
|
sesh.streamsM.Unlock()
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("failed to fetch stream 1 after receiving it")
|
t.Fatal("failed to fetch stream 1 after receiving it")
|
||||||
}
|
}
|
||||||
|
|
@ -151,8 +134,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("receiving normal frame for stream 2: %v", err)
|
t.Fatalf("receiving normal frame for stream 2: %v", err)
|
||||||
}
|
}
|
||||||
s2I, ok := sesh.streams.Load(f2.StreamID)
|
sesh.streamsM.Lock()
|
||||||
if s2I == nil || !ok {
|
s2M, ok := sesh.streams[f2.StreamID]
|
||||||
|
sesh.streamsM.Unlock()
|
||||||
|
if s2M == nil || !ok {
|
||||||
t.Fatal("failed to fetch stream 2 after receiving it")
|
t.Fatal("failed to fetch stream 2 after receiving it")
|
||||||
}
|
}
|
||||||
if sesh.streamCount() != 2 {
|
if sesh.streamCount() != 2 {
|
||||||
|
|
@ -171,8 +156,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("receiving stream closing frame for stream 1: %v", err)
|
t.Fatalf("receiving stream closing frame for stream 1: %v", err)
|
||||||
}
|
}
|
||||||
s1I, _ := sesh.streams.Load(f1.StreamID)
|
sesh.streamsM.Lock()
|
||||||
if s1I != nil {
|
s1M, _ := sesh.streams[f1.StreamID]
|
||||||
|
sesh.streamsM.Unlock()
|
||||||
|
if s1M != nil {
|
||||||
t.Fatal("stream 1 still exist after receiving stream close")
|
t.Fatal("stream 1 still exist after receiving stream close")
|
||||||
}
|
}
|
||||||
s1, _ := sesh.Accept()
|
s1, _ := sesh.Accept()
|
||||||
|
|
@ -198,8 +185,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("receiving stream closing frame for stream 1 %v", err)
|
t.Fatalf("receiving stream closing frame for stream 1 %v", err)
|
||||||
}
|
}
|
||||||
s1I, _ = sesh.streams.Load(f1.StreamID)
|
sesh.streamsM.Lock()
|
||||||
if s1I != nil {
|
s1M, _ = sesh.streams[f1.StreamID]
|
||||||
|
sesh.streamsM.Unlock()
|
||||||
|
if s1M != nil {
|
||||||
t.Error("stream 1 exists after receiving stream close for the second time")
|
t.Error("stream 1 exists after receiving stream close for the second time")
|
||||||
}
|
}
|
||||||
streamCount := sesh.streamCount()
|
streamCount := sesh.streamCount()
|
||||||
|
|
@ -245,8 +234,10 @@ func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) {
|
||||||
var sessionKey [32]byte
|
var sessionKey [32]byte
|
||||||
rand.Read(sessionKey[:])
|
rand.Read(sessionKey[:])
|
||||||
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey)
|
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey)
|
||||||
seshConfigOrdered.Obfuscator = obfuscator
|
|
||||||
sesh := MakeSession(0, seshConfigOrdered)
|
seshConfig := seshConfigs["ordered"]
|
||||||
|
seshConfig.Obfuscator = obfuscator
|
||||||
|
sesh := MakeSession(0, seshConfig)
|
||||||
|
|
||||||
// receive stream 1 closing first
|
// receive stream 1 closing first
|
||||||
f1CloseStream := &Frame{
|
f1CloseStream := &Frame{
|
||||||
|
|
@ -260,7 +251,9 @@ func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("receiving out of order stream closing frame for stream 1: %v", err)
|
t.Fatalf("receiving out of order stream closing frame for stream 1: %v", err)
|
||||||
}
|
}
|
||||||
_, ok := sesh.streams.Load(f1CloseStream.StreamID)
|
sesh.streamsM.Lock()
|
||||||
|
_, ok := sesh.streams[f1CloseStream.StreamID]
|
||||||
|
sesh.streamsM.Unlock()
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("stream 1 doesn't exist")
|
t.Fatal("stream 1 doesn't exist")
|
||||||
}
|
}
|
||||||
|
|
@ -300,8 +293,12 @@ func TestParallelStreams(t *testing.T) {
|
||||||
var sessionKey [32]byte
|
var sessionKey [32]byte
|
||||||
rand.Read(sessionKey[:])
|
rand.Read(sessionKey[:])
|
||||||
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey)
|
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey)
|
||||||
seshConfigOrdered.Obfuscator = obfuscator
|
|
||||||
sesh := MakeSession(0, seshConfigOrdered)
|
for seshType, seshConfig := range seshConfigs {
|
||||||
|
seshConfig := seshConfig
|
||||||
|
t.Run(seshType, func(t *testing.T) {
|
||||||
|
seshConfig.Obfuscator = obfuscator
|
||||||
|
sesh := MakeSession(0, seshConfig)
|
||||||
|
|
||||||
numStreams := acceptBacklog
|
numStreams := acceptBacklog
|
||||||
seqs := make([]*uint64, numStreams)
|
seqs := make([]*uint64, numStreams)
|
||||||
|
|
@ -347,24 +344,27 @@ func TestParallelStreams(t *testing.T) {
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
sc := int(sesh.streamCount())
|
sc := int(sesh.streamCount())
|
||||||
var count int
|
var count int
|
||||||
sesh.streams.Range(func(_, s interface{}) bool {
|
sesh.streamsM.Lock()
|
||||||
|
for _, s := range sesh.streams {
|
||||||
if s != nil {
|
if s != nil {
|
||||||
count++
|
count++
|
||||||
}
|
}
|
||||||
return true
|
}
|
||||||
})
|
sesh.streamsM.Unlock()
|
||||||
if sc != count {
|
if sc != count {
|
||||||
t.Errorf("broken referential integrety: actual %v, reference count: %v", count, sc)
|
t.Errorf("broken referential integrety: actual %v, reference count: %v", count, sc)
|
||||||
}
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStream_SetReadDeadline(t *testing.T) {
|
func TestStream_SetReadDeadline(t *testing.T) {
|
||||||
var sessionKey [32]byte
|
for seshType, seshConfig := range seshConfigs {
|
||||||
rand.Read(sessionKey[:])
|
seshConfig := seshConfig
|
||||||
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey)
|
t.Run(seshType, func(t *testing.T) {
|
||||||
seshConfigOrdered.Obfuscator = obfuscator
|
sesh := MakeSession(0, seshConfig)
|
||||||
|
sesh.AddConnection(connutil.Discard())
|
||||||
|
|
||||||
testReadDeadline := func(sesh *Session) {
|
|
||||||
t.Run("read after deadline set", func(t *testing.T) {
|
t.Run("read after deadline set", func(t *testing.T) {
|
||||||
stream, _ := sesh.OpenStream()
|
stream, _ := sesh.OpenStream()
|
||||||
_ = stream.SetReadDeadline(time.Now().Add(-1 * time.Second))
|
_ = stream.SetReadDeadline(time.Now().Add(-1 * time.Second))
|
||||||
|
|
@ -392,27 +392,27 @@ func TestStream_SetReadDeadline(t *testing.T) {
|
||||||
t.Error("Read did not unblock after deadline has passed")
|
t.Error("Read did not unblock after deadline has passed")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
sesh := MakeSession(0, seshConfigOrdered)
|
|
||||||
sesh.AddConnection(connutil.Discard())
|
|
||||||
testReadDeadline(sesh)
|
|
||||||
sesh = MakeSession(0, seshConfigUnordered)
|
|
||||||
sesh.AddConnection(connutil.Discard())
|
|
||||||
testReadDeadline(sesh)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSession_timeoutAfter(t *testing.T) {
|
func TestSession_timeoutAfter(t *testing.T) {
|
||||||
var sessionKey [32]byte
|
var sessionKey [32]byte
|
||||||
rand.Read(sessionKey[:])
|
rand.Read(sessionKey[:])
|
||||||
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey)
|
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey)
|
||||||
seshConfigOrdered.Obfuscator = obfuscator
|
|
||||||
seshConfigOrdered.InactivityTimeout = 100 * time.Millisecond
|
for seshType, seshConfig := range seshConfigs {
|
||||||
sesh := MakeSession(0, seshConfigOrdered)
|
seshConfig := seshConfig
|
||||||
|
t.Run(seshType, func(t *testing.T) {
|
||||||
|
seshConfig.Obfuscator = obfuscator
|
||||||
|
seshConfig.InactivityTimeout = 100 * time.Millisecond
|
||||||
|
sesh := MakeSession(0, seshConfig)
|
||||||
|
|
||||||
assert.Eventually(t, func() bool {
|
assert.Eventually(t, func() bool {
|
||||||
return sesh.IsClosed()
|
return sesh.IsClosed()
|
||||||
}, 5*seshConfigOrdered.InactivityTimeout, seshConfigOrdered.InactivityTimeout, "session should have timed out")
|
}, 5*seshConfig.InactivityTimeout, seshConfig.InactivityTimeout, "session should have timed out")
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) {
|
func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) {
|
||||||
|
|
@ -424,47 +424,73 @@ func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) {
|
||||||
0,
|
0,
|
||||||
testPayload,
|
testPayload,
|
||||||
}
|
}
|
||||||
obfsBuf := make([]byte, obfsBufLen)
|
|
||||||
|
|
||||||
var sessionKey [32]byte
|
var sessionKey [32]byte
|
||||||
rand.Read(sessionKey[:])
|
rand.Read(sessionKey[:])
|
||||||
|
|
||||||
b.Run("plain", func(b *testing.B) {
|
table := map[string]byte{
|
||||||
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey)
|
"plain": EncryptionMethodPlain,
|
||||||
seshConfigOrdered.Obfuscator = obfuscator
|
"aes-gcm": EncryptionMethodAESGCM,
|
||||||
sesh := MakeSession(0, seshConfigOrdered)
|
"chacha20poly1305": EncryptionMethodChaha20Poly1305,
|
||||||
|
}
|
||||||
|
|
||||||
|
const maxIter = 100_000 // run with -benchtime 100000x to avoid index out of bounds panic
|
||||||
|
for name, ep := range table {
|
||||||
|
ep := ep
|
||||||
|
b.Run(name, func(b *testing.B) {
|
||||||
|
seshConfig := seshConfigs["ordered"]
|
||||||
|
obfuscator, _ := MakeObfuscator(ep, sessionKey)
|
||||||
|
seshConfig.Obfuscator = obfuscator
|
||||||
|
sesh := MakeSession(0, seshConfig)
|
||||||
|
|
||||||
|
binaryFrames := [maxIter][]byte{}
|
||||||
|
for i := 0; i < maxIter; i++ {
|
||||||
|
obfsBuf := make([]byte, obfsBufLen)
|
||||||
n, _ := sesh.Obfs(f, obfsBuf, 0)
|
n, _ := sesh.Obfs(f, obfsBuf, 0)
|
||||||
|
binaryFrames[i] = obfsBuf[:n]
|
||||||
|
f.Seq++
|
||||||
|
}
|
||||||
|
|
||||||
b.SetBytes(int64(len(f.Payload)))
|
b.SetBytes(int64(len(f.Payload)))
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
sesh.recvDataFromRemote(obfsBuf[:n])
|
sesh.recvDataFromRemote(binaryFrames[i])
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
b.Run("aes-gcm", func(b *testing.B) {
|
func BenchmarkMultiStreamWrite(b *testing.B) {
|
||||||
obfuscator, _ := MakeObfuscator(EncryptionMethodAESGCM, sessionKey)
|
var sessionKey [32]byte
|
||||||
seshConfigOrdered.Obfuscator = obfuscator
|
rand.Read(sessionKey[:])
|
||||||
sesh := MakeSession(0, seshConfigOrdered)
|
|
||||||
n, _ := sesh.Obfs(f, obfsBuf, 0)
|
|
||||||
|
|
||||||
b.SetBytes(int64(len(f.Payload)))
|
table := map[string]byte{
|
||||||
|
"plain": EncryptionMethodPlain,
|
||||||
|
"aes-gcm": EncryptionMethodAESGCM,
|
||||||
|
"chacha20poly1305": EncryptionMethodChaha20Poly1305,
|
||||||
|
}
|
||||||
|
|
||||||
|
testPayload := make([]byte, testPayloadLen)
|
||||||
|
|
||||||
|
for name, ep := range table {
|
||||||
|
b.Run(name, func(b *testing.B) {
|
||||||
|
for seshType, seshConfig := range seshConfigs {
|
||||||
|
seshConfig := seshConfig
|
||||||
|
b.Run(seshType, func(b *testing.B) {
|
||||||
|
obfuscator, _ := MakeObfuscator(ep, sessionKey)
|
||||||
|
seshConfig.Obfuscator = obfuscator
|
||||||
|
sesh := MakeSession(0, seshConfig)
|
||||||
|
sesh.AddConnection(connutil.Discard())
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
b.SetBytes(testPayloadLen)
|
||||||
sesh.recvDataFromRemote(obfsBuf[:n])
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
stream, _ := sesh.OpenStream()
|
||||||
|
for pb.Next() {
|
||||||
|
stream.Write(testPayload)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
})
|
||||||
b.Run("chacha20-poly1305", func(b *testing.B) {
|
|
||||||
obfuscator, _ := MakeObfuscator(EncryptionMethodChaha20Poly1305, sessionKey)
|
|
||||||
seshConfigOrdered.Obfuscator = obfuscator
|
|
||||||
sesh := MakeSession(0, seshConfigOrdered)
|
|
||||||
n, _ := sesh.Obfs(f, obfsBuf, 0)
|
|
||||||
|
|
||||||
b.SetBytes(int64(len(f.Payload)))
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
sesh.recvDataFromRemote(obfsBuf[:n])
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -23,21 +23,20 @@ type Stream struct {
|
||||||
|
|
||||||
session *Session
|
session *Session
|
||||||
|
|
||||||
|
allocIdempot sync.Once
|
||||||
// a buffer (implemented as an asynchronous buffered pipe) to put data we've received from recvFrame but hasn't
|
// a buffer (implemented as an asynchronous buffered pipe) to put data we've received from recvFrame but hasn't
|
||||||
// been read by the consumer through Read or WriteTo
|
// been read by the consumer through Read or WriteTo. Lazily allocated
|
||||||
recvBuf recvBuffer
|
recvBuf recvBuffer
|
||||||
|
|
||||||
writingM sync.Mutex
|
writingM sync.Mutex
|
||||||
nextSendSeq uint64
|
writingFrame Frame // we do the allocation here to save repeated allocations in Write and ReadFrom
|
||||||
|
|
||||||
// atomic
|
// atomic
|
||||||
closed uint32
|
closed uint32
|
||||||
|
|
||||||
// lazy allocation for obfsBuf. This is desirable because obfsBuf is only used when data is sent from
|
// obfuscation happens in obfsBuf. This buffer is lazily allocated as obfsBuf is only used when data is sent from
|
||||||
// the stream (through Write or ReadFrom). Some streams never send data so eager allocation will waste
|
// the stream (through Write or ReadFrom). Some streams never send data so eager allocation will waste
|
||||||
// memory
|
// memory
|
||||||
allocIdempot sync.Once
|
|
||||||
// obfuscation happens in this buffer
|
|
||||||
obfsBuf []byte
|
obfsBuf []byte
|
||||||
|
|
||||||
// When we want order guarantee (i.e. session.Unordered is false),
|
// When we want order guarantee (i.e. session.Unordered is false),
|
||||||
|
|
@ -52,17 +51,14 @@ type Stream struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeStream(sesh *Session, id uint32) *Stream {
|
func makeStream(sesh *Session, id uint32) *Stream {
|
||||||
var recvBuf recvBuffer
|
|
||||||
if sesh.Unordered {
|
|
||||||
recvBuf = NewDatagramBufferedPipe()
|
|
||||||
} else {
|
|
||||||
recvBuf = NewStreamBuffer()
|
|
||||||
}
|
|
||||||
|
|
||||||
stream := &Stream{
|
stream := &Stream{
|
||||||
id: id,
|
id: id,
|
||||||
session: sesh,
|
session: sesh,
|
||||||
recvBuf: recvBuf,
|
writingFrame: Frame{
|
||||||
|
StreamID: id,
|
||||||
|
Seq: 0,
|
||||||
|
Closing: closingNothing,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
return stream
|
return stream
|
||||||
|
|
@ -70,9 +66,20 @@ func makeStream(sesh *Session, id uint32) *Stream {
|
||||||
|
|
||||||
func (s *Stream) isClosed() bool { return atomic.LoadUint32(&s.closed) == 1 }
|
func (s *Stream) isClosed() bool { return atomic.LoadUint32(&s.closed) == 1 }
|
||||||
|
|
||||||
|
func (s *Stream) getRecvBuf() recvBuffer {
|
||||||
|
s.allocIdempot.Do(func() {
|
||||||
|
if s.session.Unordered {
|
||||||
|
s.recvBuf = NewDatagramBufferedPipe()
|
||||||
|
} else {
|
||||||
|
s.recvBuf = NewStreamBuffer()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return s.recvBuf
|
||||||
|
}
|
||||||
|
|
||||||
// receive a readily deobfuscated Frame so its payload can later be Read
|
// receive a readily deobfuscated Frame so its payload can later be Read
|
||||||
func (s *Stream) recvFrame(frame Frame) error {
|
func (s *Stream) recvFrame(frame *Frame) error {
|
||||||
toBeClosed, err := s.recvBuf.Write(frame)
|
toBeClosed, err := s.getRecvBuf().Write(frame)
|
||||||
if toBeClosed {
|
if toBeClosed {
|
||||||
err = s.passiveClose()
|
err = s.passiveClose()
|
||||||
if errors.Is(err, errRepeatStreamClosing) {
|
if errors.Is(err, errRepeatStreamClosing) {
|
||||||
|
|
@ -91,7 +98,7 @@ func (s *Stream) Read(buf []byte) (n int, err error) {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
n, err = s.recvBuf.Read(buf)
|
n, err = s.getRecvBuf().Read(buf)
|
||||||
log.Tracef("%v read from stream %v with err %v", n, s.id, err)
|
log.Tracef("%v read from stream %v with err %v", n, s.id, err)
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
return n, ErrBrokenStream
|
return n, ErrBrokenStream
|
||||||
|
|
@ -102,7 +109,7 @@ func (s *Stream) Read(buf []byte) (n int, err error) {
|
||||||
// WriteTo continuously write data Stream has received into the writer w.
|
// WriteTo continuously write data Stream has received into the writer w.
|
||||||
func (s *Stream) WriteTo(w io.Writer) (int64, error) {
|
func (s *Stream) WriteTo(w io.Writer) (int64, error) {
|
||||||
// will keep writing until the underlying buffer is closed
|
// will keep writing until the underlying buffer is closed
|
||||||
n, err := s.recvBuf.WriteTo(w)
|
n, err := s.getRecvBuf().WriteTo(w)
|
||||||
log.Tracef("%v read from stream %v with err %v", n, s.id, err)
|
log.Tracef("%v read from stream %v with err %v", n, s.id, err)
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
return n, ErrBrokenStream
|
return n, ErrBrokenStream
|
||||||
|
|
@ -110,15 +117,14 @@ func (s *Stream) WriteTo(w io.Writer) (int64, error) {
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Stream) obfuscateAndSend(f *Frame, payloadOffsetInObfsBuf int) error {
|
func (s *Stream) obfuscateAndSend(payloadOffsetInObfsBuf int) error {
|
||||||
var cipherTextLen int
|
var cipherTextLen int
|
||||||
cipherTextLen, err := s.session.Obfs(f, s.obfsBuf, payloadOffsetInObfsBuf)
|
cipherTextLen, err := s.session.Obfs(&s.writingFrame, s.obfsBuf, payloadOffsetInObfsBuf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = s.session.sb.send(s.obfsBuf[:cipherTextLen], &s.assignedConnId)
|
_, err = s.session.sb.send(s.obfsBuf[:cipherTextLen], &s.assignedConnId)
|
||||||
log.Tracef("%v sent to remote through stream %v with err %v. seq: %v", len(f.Payload), s.id, err, f.Seq)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == errBrokenSwitchboard {
|
if err == errBrokenSwitchboard {
|
||||||
s.session.SetTerminalMsg(err.Error())
|
s.session.SetTerminalMsg(err.Error())
|
||||||
|
|
@ -154,14 +160,9 @@ func (s *Stream) Write(in []byte) (n int, err error) {
|
||||||
}
|
}
|
||||||
framePayload = in[n : s.session.maxStreamUnitWrite+n]
|
framePayload = in[n : s.session.maxStreamUnitWrite+n]
|
||||||
}
|
}
|
||||||
f := &Frame{
|
s.writingFrame.Payload = framePayload
|
||||||
StreamID: s.id,
|
err = s.obfuscateAndSend(0)
|
||||||
Seq: s.nextSendSeq,
|
s.writingFrame.Seq++
|
||||||
Closing: closingNothing,
|
|
||||||
Payload: framePayload,
|
|
||||||
}
|
|
||||||
s.nextSendSeq++
|
|
||||||
err = s.obfuscateAndSend(f, 0)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -193,14 +194,9 @@ func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
s.writingM.Lock()
|
s.writingM.Lock()
|
||||||
f := &Frame{
|
s.writingFrame.Payload = s.obfsBuf[frameHeaderLength : frameHeaderLength+read]
|
||||||
StreamID: s.id,
|
err = s.obfuscateAndSend(frameHeaderLength)
|
||||||
Seq: s.nextSendSeq,
|
s.writingFrame.Seq++
|
||||||
Closing: closingNothing,
|
|
||||||
Payload: s.obfsBuf[frameHeaderLength : frameHeaderLength+read],
|
|
||||||
}
|
|
||||||
s.nextSendSeq++
|
|
||||||
err = s.obfuscateAndSend(f, frameHeaderLength)
|
|
||||||
s.writingM.Unlock()
|
s.writingM.Unlock()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -225,8 +221,8 @@ func (s *Stream) Close() error {
|
||||||
func (s *Stream) LocalAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[0] }
|
func (s *Stream) LocalAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[0] }
|
||||||
func (s *Stream) RemoteAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[1] }
|
func (s *Stream) RemoteAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[1] }
|
||||||
|
|
||||||
func (s *Stream) SetWriteToTimeout(d time.Duration) { s.recvBuf.SetWriteToTimeout(d) }
|
func (s *Stream) SetWriteToTimeout(d time.Duration) { s.getRecvBuf().SetWriteToTimeout(d) }
|
||||||
func (s *Stream) SetReadDeadline(t time.Time) error { s.recvBuf.SetReadDeadline(t); return nil }
|
func (s *Stream) SetReadDeadline(t time.Time) error { s.getRecvBuf().SetReadDeadline(t); return nil }
|
||||||
func (s *Stream) SetReadFromTimeout(d time.Duration) { s.readFromTimeout = d }
|
func (s *Stream) SetReadFromTimeout(d time.Duration) { s.readFromTimeout = d }
|
||||||
|
|
||||||
var errNotImplemented = errors.New("Not implemented")
|
var errNotImplemented = errors.New("Not implemented")
|
||||||
|
|
|
||||||
|
|
@ -63,7 +63,7 @@ func NewStreamBuffer() *streamBuffer {
|
||||||
return sb
|
return sb
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sb *streamBuffer) Write(f Frame) (toBeClosed bool, err error) {
|
func (sb *streamBuffer) Write(f *Frame) (toBeClosed bool, err error) {
|
||||||
sb.recvM.Lock()
|
sb.recvM.Lock()
|
||||||
defer sb.recvM.Unlock()
|
defer sb.recvM.Unlock()
|
||||||
// when there'fs no ooo packages in heap and we receive the next package in order
|
// when there'fs no ooo packages in heap and we receive the next package in order
|
||||||
|
|
@ -81,10 +81,11 @@ func (sb *streamBuffer) Write(f Frame) (toBeClosed bool, err error) {
|
||||||
return false, fmt.Errorf("seq %v is smaller than nextRecvSeq %v", f.Seq, sb.nextRecvSeq)
|
return false, fmt.Errorf("seq %v is smaller than nextRecvSeq %v", f.Seq, sb.nextRecvSeq)
|
||||||
}
|
}
|
||||||
|
|
||||||
heap.Push(&sb.sh, &f)
|
saved := *f
|
||||||
|
heap.Push(&sb.sh, &saved)
|
||||||
// Keep popping from the heap until empty or to the point that the wanted seq was not received
|
// Keep popping from the heap until empty or to the point that the wanted seq was not received
|
||||||
for len(sb.sh) > 0 && sb.sh[0].Seq == sb.nextRecvSeq {
|
for len(sb.sh) > 0 && sb.sh[0].Seq == sb.nextRecvSeq {
|
||||||
f = *heap.Pop(&sb.sh).(*Frame)
|
f = heap.Pop(&sb.sh).(*Frame)
|
||||||
if f.Closing != closingNothing {
|
if f.Closing != closingNothing {
|
||||||
return true, nil
|
return true, nil
|
||||||
} else {
|
} else {
|
||||||
|
|
|
||||||
|
|
@ -20,11 +20,10 @@ func TestRecvNewFrame(t *testing.T) {
|
||||||
for _, n := range set {
|
for _, n := range set {
|
||||||
bu64 := make([]byte, 8)
|
bu64 := make([]byte, 8)
|
||||||
binary.BigEndian.PutUint64(bu64, n)
|
binary.BigEndian.PutUint64(bu64, n)
|
||||||
frame := Frame{
|
sb.Write(&Frame{
|
||||||
Seq: n,
|
Seq: n,
|
||||||
Payload: bu64,
|
Payload: bu64,
|
||||||
}
|
})
|
||||||
sb.Write(frame)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var sortedResult []uint64
|
var sortedResult []uint64
|
||||||
|
|
@ -80,7 +79,7 @@ func TestStreamBuffer_RecvThenClose(t *testing.T) {
|
||||||
Closing: 0,
|
Closing: 0,
|
||||||
Payload: testData,
|
Payload: testData,
|
||||||
}
|
}
|
||||||
sb.Write(testFrame)
|
sb.Write(&testFrame)
|
||||||
sb.Close()
|
sb.Close()
|
||||||
|
|
||||||
readBuf := make([]byte, testDataLen)
|
readBuf := make([]byte, testDataLen)
|
||||||
|
|
|
||||||
|
|
@ -151,19 +151,31 @@ func TestStream_Close(t *testing.T) {
|
||||||
t.Error("failed to accept stream", err)
|
t.Error("failed to accept stream", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// we read something to wait for the test frame to reach our recvBuffer.
|
||||||
|
// if it's empty by the point we call stream.Close(), the incoming
|
||||||
|
// frame will be dropped
|
||||||
|
readBuf := make([]byte, len(testPayload))
|
||||||
|
_, err = io.ReadFull(stream, readBuf[:1])
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("can't read any data before active closing")
|
||||||
|
}
|
||||||
|
|
||||||
err = stream.Close()
|
err = stream.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("failed to actively close stream", err)
|
t.Error("failed to actively close stream", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if sI, _ := sesh.streams.Load(stream.(*Stream).id); sI != nil {
|
sesh.streamsM.Lock()
|
||||||
|
if s, _ := sesh.streams[stream.(*Stream).id]; s != nil {
|
||||||
|
sesh.streamsM.Unlock()
|
||||||
t.Error("stream still exists")
|
t.Error("stream still exists")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
sesh.streamsM.Unlock()
|
||||||
|
|
||||||
readBuf := make([]byte, len(testPayload))
|
_, err = io.ReadFull(stream, readBuf[1:])
|
||||||
_, err = io.ReadFull(stream, readBuf)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("can't read residual data %v", err)
|
t.Errorf("can't read residual data %v", err)
|
||||||
}
|
}
|
||||||
|
|
@ -233,8 +245,10 @@ func TestStream_Close(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Eventually(t, func() bool {
|
assert.Eventually(t, func() bool {
|
||||||
sI, _ := sesh.streams.Load(stream.(*Stream).id)
|
sesh.streamsM.Lock()
|
||||||
return sI == nil
|
s, _ := sesh.streams[stream.(*Stream).id]
|
||||||
|
sesh.streamsM.Unlock()
|
||||||
|
return s == nil
|
||||||
}, time.Second, 10*time.Millisecond, "streams still exists")
|
}, time.Second, 10*time.Millisecond, "streams still exists")
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -79,7 +79,13 @@ func (TLS) unmarshalClientHello(ch *ClientHello, staticPv crypto.PrivateKey) (fr
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
copy(fragments.sharedSecret[:], ecdh.GenerateSharedSecret(staticPv, ephPub))
|
var sharedSecret []byte
|
||||||
|
sharedSecret, err = ecdh.GenerateSharedSecret(staticPv, ephPub)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(fragments.sharedSecret[:], sharedSecret)
|
||||||
var keyShare []byte
|
var keyShare []byte
|
||||||
keyShare, err = parseKeyShare(ch.extensions[[2]byte{0x00, 0x33}])
|
keyShare, err = parseKeyShare(ch.extensions[[2]byte{0x00, 0x33}])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -143,10 +143,15 @@ func InitState(preParse RawConfig, worldState common.WorldState) (sta *State, er
|
||||||
err = errors.New("command & control mode not implemented")
|
err = errors.New("command & control mode not implemented")
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
manager, err := usermanager.MakeLocalManager(preParse.DatabasePath, worldState)
|
var manager usermanager.UserManager
|
||||||
|
if len(preParse.AdminUID) == 0 || preParse.DatabasePath == "" {
|
||||||
|
manager = &usermanager.Voidmanager{}
|
||||||
|
} else {
|
||||||
|
manager, err = usermanager.MakeLocalManager(preParse.DatabasePath, worldState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return sta, err
|
return sta, err
|
||||||
}
|
}
|
||||||
|
}
|
||||||
sta.Panel = MakeUserPanel(manager)
|
sta.Panel = MakeUserPanel(manager)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -40,6 +40,7 @@ const (
|
||||||
|
|
||||||
var ErrUserNotFound = errors.New("UID does not correspond to a user")
|
var ErrUserNotFound = errors.New("UID does not correspond to a user")
|
||||||
var ErrSessionsCapReached = errors.New("Sessions cap has reached")
|
var ErrSessionsCapReached = errors.New("Sessions cap has reached")
|
||||||
|
var ErrMangerIsVoid = errors.New("cannot perform operation with user manager as database path is not specified")
|
||||||
|
|
||||||
var ErrNoUpCredit = errors.New("No upload credit left")
|
var ErrNoUpCredit = errors.New("No upload credit left")
|
||||||
var ErrNoDownCredit = errors.New("No download credit left")
|
var ErrNoDownCredit = errors.New("No download credit left")
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,31 @@
|
||||||
|
package usermanager
|
||||||
|
|
||||||
|
type Voidmanager struct{}
|
||||||
|
|
||||||
|
func (v *Voidmanager) AuthenticateUser(bytes []byte) (int64, int64, error) {
|
||||||
|
return 0, 0, ErrMangerIsVoid
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *Voidmanager) AuthoriseNewSession(bytes []byte, info AuthorisationInfo) error {
|
||||||
|
return ErrMangerIsVoid
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *Voidmanager) UploadStatus(updates []StatusUpdate) ([]StatusResponse, error) {
|
||||||
|
return nil, ErrMangerIsVoid
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *Voidmanager) ListAllUsers() ([]UserInfo, error) {
|
||||||
|
return nil, ErrMangerIsVoid
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *Voidmanager) GetUserInfo(UID []byte) (UserInfo, error) {
|
||||||
|
return UserInfo{}, ErrMangerIsVoid
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *Voidmanager) WriteUserInfo(info UserInfo) error {
|
||||||
|
return ErrMangerIsVoid
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *Voidmanager) DeleteUser(UID []byte) error {
|
||||||
|
return ErrMangerIsVoid
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,43 @@
|
||||||
|
package usermanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
var v = &Voidmanager{}
|
||||||
|
|
||||||
|
func Test_Voidmanager_AuthenticateUser(t *testing.T) {
|
||||||
|
_, _, err := v.AuthenticateUser([]byte{})
|
||||||
|
assert.Equal(t, ErrMangerIsVoid, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_Voidmanager_AuthoriseNewSession(t *testing.T) {
|
||||||
|
err := v.AuthoriseNewSession([]byte{}, AuthorisationInfo{})
|
||||||
|
assert.Equal(t, ErrMangerIsVoid, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_Voidmanager_DeleteUser(t *testing.T) {
|
||||||
|
err := v.DeleteUser([]byte{})
|
||||||
|
assert.Equal(t, ErrMangerIsVoid, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_Voidmanager_GetUserInfo(t *testing.T) {
|
||||||
|
_, err := v.GetUserInfo([]byte{})
|
||||||
|
assert.Equal(t, ErrMangerIsVoid, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_Voidmanager_ListAllUsers(t *testing.T) {
|
||||||
|
_, err := v.ListAllUsers()
|
||||||
|
assert.Equal(t, ErrMangerIsVoid, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_Voidmanager_UploadStatus(t *testing.T) {
|
||||||
|
_, err := v.UploadStatus([]StatusUpdate{})
|
||||||
|
assert.Equal(t, ErrMangerIsVoid, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_Voidmanager_WriteUserInfo(t *testing.T) {
|
||||||
|
err := v.WriteUserInfo(UserInfo{})
|
||||||
|
assert.Equal(t, ErrMangerIsVoid, err)
|
||||||
|
}
|
||||||
|
|
@ -185,6 +185,9 @@ func (panel *userPanel) commitUpdate() error {
|
||||||
panel.usageUpdateQueue = make(map[[16]byte]*usagePair)
|
panel.usageUpdateQueue = make(map[[16]byte]*usagePair)
|
||||||
panel.usageUpdateQueueM.Unlock()
|
panel.usageUpdateQueueM.Unlock()
|
||||||
|
|
||||||
|
if len(statuses) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
responses, err := panel.Manager.UploadStatus(statuses)
|
responses, err := panel.Manager.UploadStatus(statuses)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
||||||
|
|
@ -84,7 +84,13 @@ func (WebSocket) unmarshalHidden(hidden []byte, staticPv crypto.PrivateKey) (fra
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
copy(fragments.sharedSecret[:], ecdh.GenerateSharedSecret(staticPv, ephPub))
|
var sharedSecret []byte
|
||||||
|
sharedSecret, err = ecdh.GenerateSharedSecret(staticPv, ephPub)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(fragments.sharedSecret[:], sharedSecret)
|
||||||
|
|
||||||
if len(hidden[32:]) != 64 {
|
if len(hidden[32:]) != 64 {
|
||||||
err = fmt.Errorf("%v: %v", ErrCiphertextLength, len(hidden[32:]))
|
err = fmt.Errorf("%v: %v", ErrCiphertextLength, len(hidden[32:]))
|
||||||
|
|
|
||||||
|
|
@ -12,10 +12,8 @@ import (
|
||||||
"github.com/cbeuw/connutil"
|
"github.com/cbeuw/connutil"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
@ -24,8 +22,6 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const numConns = 200 // -race option limits the number of goroutines to 8192
|
const numConns = 200 // -race option limits the number of goroutines to 8192
|
||||||
const delayBeforeTestingConnClose = 500 * time.Millisecond
|
|
||||||
const connCloseRetries = 3
|
|
||||||
|
|
||||||
func serveTCPEcho(l net.Listener) {
|
func serveTCPEcho(l net.Listener) {
|
||||||
for {
|
for {
|
||||||
|
|
@ -137,15 +133,13 @@ func generateClientConfigs(rawConfig client.RawConfig, state common.WorldState)
|
||||||
return lcl, rmt, auth
|
return lcl, rmt, auth
|
||||||
}
|
}
|
||||||
|
|
||||||
func basicServerState(ws common.WorldState, db *os.File) *server.State {
|
func basicServerState(ws common.WorldState) *server.State {
|
||||||
var serverConfig = server.RawConfig{
|
var serverConfig = server.RawConfig{
|
||||||
ProxyBook: map[string][]string{"shadowsocks": {"tcp", "fake.com:9999"}, "openvpn": {"udp", "fake.com:9999"}},
|
ProxyBook: map[string][]string{"shadowsocks": {"tcp", "fake.com:9999"}, "openvpn": {"udp", "fake.com:9999"}},
|
||||||
BindAddr: []string{"fake.com:9999"},
|
BindAddr: []string{"fake.com:9999"},
|
||||||
BypassUID: [][]byte{bypassUID[:]},
|
BypassUID: [][]byte{bypassUID[:]},
|
||||||
RedirAddr: "fake.com:9999",
|
RedirAddr: "fake.com:9999",
|
||||||
PrivateKey: privateKey,
|
PrivateKey: privateKey,
|
||||||
AdminUID: nil,
|
|
||||||
DatabasePath: db.Name(),
|
|
||||||
KeepAlive: 15,
|
KeepAlive: 15,
|
||||||
CncMode: false,
|
CncMode: false,
|
||||||
}
|
}
|
||||||
|
|
@ -258,13 +252,11 @@ func runEchoTest(t *testing.T, conns []net.Conn, maxMsgLen int) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUDP(t *testing.T) {
|
func TestUDP(t *testing.T) {
|
||||||
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
|
|
||||||
defer os.Remove(tmpDB.Name())
|
|
||||||
log.SetLevel(log.ErrorLevel)
|
log.SetLevel(log.ErrorLevel)
|
||||||
|
|
||||||
worldState := common.WorldOfTime(time.Unix(10, 0))
|
worldState := common.WorldOfTime(time.Unix(10, 0))
|
||||||
lcc, rcc, ai := generateClientConfigs(basicUDPConfig, worldState)
|
lcc, rcc, ai := generateClientConfigs(basicUDPConfig, worldState)
|
||||||
sta := basicServerState(worldState, tmpDB)
|
sta := basicServerState(worldState)
|
||||||
|
|
||||||
proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta)
|
proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -319,9 +311,7 @@ func TestTCPSingleplex(t *testing.T) {
|
||||||
log.SetLevel(log.ErrorLevel)
|
log.SetLevel(log.ErrorLevel)
|
||||||
worldState := common.WorldOfTime(time.Unix(10, 0))
|
worldState := common.WorldOfTime(time.Unix(10, 0))
|
||||||
lcc, rcc, ai := generateClientConfigs(singleplexTCPConfig, worldState)
|
lcc, rcc, ai := generateClientConfigs(singleplexTCPConfig, worldState)
|
||||||
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
|
sta := basicServerState(worldState)
|
||||||
defer os.Remove(tmpDB.Name())
|
|
||||||
sta := basicServerState(worldState, tmpDB)
|
|
||||||
proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta)
|
proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
|
@ -381,9 +371,7 @@ func TestTCPMultiplex(t *testing.T) {
|
||||||
worldState := common.WorldOfTime(time.Unix(10, 0))
|
worldState := common.WorldOfTime(time.Unix(10, 0))
|
||||||
|
|
||||||
lcc, rcc, ai := generateClientConfigs(basicTCPConfig, worldState)
|
lcc, rcc, ai := generateClientConfigs(basicTCPConfig, worldState)
|
||||||
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
|
sta := basicServerState(worldState)
|
||||||
defer os.Remove(tmpDB.Name())
|
|
||||||
sta := basicServerState(worldState, tmpDB)
|
|
||||||
|
|
||||||
proxyToCkClientD, proxyFromCkServerL, netToCkServerD, redirFromCkServerL, err := establishSession(lcc, rcc, ai, sta)
|
proxyToCkClientD, proxyFromCkServerL, netToCkServerD, redirFromCkServerL, err := establishSession(lcc, rcc, ai, sta)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -456,11 +444,8 @@ func TestClosingStreamsFromProxy(t *testing.T) {
|
||||||
clientConfig := clientConfig
|
clientConfig := clientConfig
|
||||||
clientConfigName := clientConfigName
|
clientConfigName := clientConfigName
|
||||||
t.Run(clientConfigName, func(t *testing.T) {
|
t.Run(clientConfigName, func(t *testing.T) {
|
||||||
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
|
|
||||||
defer os.Remove(tmpDB.Name())
|
|
||||||
|
|
||||||
lcc, rcc, ai := generateClientConfigs(clientConfig, worldState)
|
lcc, rcc, ai := generateClientConfigs(clientConfig, worldState)
|
||||||
sta := basicServerState(worldState, tmpDB)
|
sta := basicServerState(worldState)
|
||||||
proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta)
|
proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
|
@ -519,12 +504,10 @@ func TestClosingStreamsFromProxy(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkThroughput(b *testing.B) {
|
func BenchmarkThroughput(b *testing.B) {
|
||||||
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
|
|
||||||
defer os.Remove(tmpDB.Name())
|
|
||||||
log.SetLevel(log.ErrorLevel)
|
log.SetLevel(log.ErrorLevel)
|
||||||
worldState := common.WorldOfTime(time.Unix(10, 0))
|
worldState := common.WorldOfTime(time.Unix(10, 0))
|
||||||
lcc, rcc, ai := generateClientConfigs(basicTCPConfig, worldState)
|
lcc, rcc, ai := generateClientConfigs(basicTCPConfig, worldState)
|
||||||
sta := basicServerState(worldState, tmpDB)
|
sta := basicServerState(worldState)
|
||||||
const bufSize = 16 * 1024
|
const bufSize = 16 * 1024
|
||||||
|
|
||||||
encryptionMethods := map[string]byte{
|
encryptionMethods := map[string]byte{
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
go get github.com/mitchellh/gox
|
go get github.com/mitchellh/gox
|
||||||
|
|
||||||
mkdir -p release
|
mkdir -p release
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue