Merge branch 'master' into notsure2

This commit is contained in:
notsure2 2020-12-22 23:36:07 +02:00
commit 9d79842536
26 changed files with 480 additions and 347 deletions

View File

@ -13,10 +13,10 @@ jobs:
- name: Build - name: Build
run: | run: |
export PATH=${PATH}:`go env GOPATH`/bin export PATH=${PATH}:`go env GOPATH`/bin
v=${{ github.ref }} ./release.sh v=${GITHUB_REF#refs/*/} ./release.sh
- name: Release - name: Release
uses: softprops/action-gh-release@v1 uses: softprops/action-gh-release@v1
with: with:
files: ./release/ck-* files: release/*
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@ -103,15 +103,13 @@ Example:
`PrivateKey` is the static curve25519 Diffie-Hellman private key encoded in base64. `PrivateKey` is the static curve25519 Diffie-Hellman private key encoded in base64.
`AdminUID` is the UID of the admin user in base64.
`BypassUID` is a list of UIDs that are authorised without any bandwidth or credit limit restrictions `BypassUID` is a list of UIDs that are authorised without any bandwidth or credit limit restrictions
`DatabasePath` is the path to `userinfo.db`. If `userinfo.db` doesn't exist in this directory, Cloak will create one `AdminUID` is the UID of the admin user in base64. You can leave this empty if you only ever add users to `BypassUID`.
automatically. **If Cloak is started as a Shadowsocks plugin and Shadowsocks is started with its working directory as
/ (e.g. starting ss-server with systemctl), you need to set this field as an absolute path to a desired folder. If you `DatabasePath` is the path to `userinfo.db`, which is used to store user usage information and restrictions. Cloak will
leave it as default then Cloak will attempt to create userinfo.db under /, which it doesn't have the permission to do so create the file automatically if it doesn't exist. You can leave this empty if you only ever add users to `BypassUID`.
and will raise an error. See Issue #13.** This field also has no effect if `AdminUID` isn't a valid UID or is empty.
`KeepAlive` is the number of seconds to tell the OS to wait after no activity before sending TCP KeepAlive probes to the `KeepAlive` is the number of seconds to tell the OS to wait after no activity before sending TCP KeepAlive probes to the
upstream proxy server. Zero or negative value disables it. Default is 0 (disabled). upstream proxy server. Zero or negative value disables it. Default is 0 (disabled).
@ -184,6 +182,8 @@ Run `ck-server -uid` and add the UID into the `BypassUID` field in `ckserver.jso
##### Users subject to bandwidth and credit controls ##### Users subject to bandwidth and credit controls
0. First make sure you have `AdminUID` generated and set in `ckserver.json`, along with a path to `userinfo.db`
in `DatabasePath` (Cloak will create this file for you if it didn't already exist).
1. On your client, run `ck-client -s <IP of the server> -l <A local port> -a <AdminUID> -c <path-to-ckclient.json>` to 1. On your client, run `ck-client -s <IP of the server> -l <A local port> -a <AdminUID> -c <path-to-ckclient.json>` to
enter admin mode enter admin mode
2. Visit https://cbeuw.github.io/Cloak-panel (Note: this is a pure-js static site, there is no backend and all data 2. Visit https://cbeuw.github.io/Cloak-panel (Note: this is a pure-js static site, there is no backend and all data

View File

@ -4,6 +4,7 @@ import (
"encoding/binary" "encoding/binary"
"github.com/cbeuw/Cloak/internal/common" "github.com/cbeuw/Cloak/internal/common"
"github.com/cbeuw/Cloak/internal/ecdh" "github.com/cbeuw/Cloak/internal/ecdh"
log "github.com/sirupsen/logrus"
) )
const ( const (
@ -26,7 +27,10 @@ func makeAuthenticationPayload(authInfo AuthInfo) (ret authenticationPayload, sh
| 16 bytes | 12 bytes | 1 byte | 8 bytes | 4 bytes | 1 byte | 6 bytes | | 16 bytes | 12 bytes | 1 byte | 8 bytes | 4 bytes | 1 byte | 6 bytes |
+----------+----------------+---------------------+-------------+--------------+--------+------------+ +----------+----------------+---------------------+-------------+--------------+--------+------------+
*/ */
ephPv, ephPub, _ := ecdh.GenerateKey(authInfo.WorldState.Rand) ephPv, ephPub, err := ecdh.GenerateKey(authInfo.WorldState.Rand)
if err != nil {
log.Panicf("failed to generate ephemeral key pair: %v", err)
}
copy(ret.randPubKey[:], ecdh.Marshal(ephPub)) copy(ret.randPubKey[:], ecdh.Marshal(ephPub))
plaintext := make([]byte, 48) plaintext := make([]byte, 48)
@ -40,7 +44,11 @@ func makeAuthenticationPayload(authInfo AuthInfo) (ret authenticationPayload, sh
plaintext[41] |= UNORDERED_FLAG plaintext[41] |= UNORDERED_FLAG
} }
copy(sharedSecret[:], ecdh.GenerateSharedSecret(ephPv, authInfo.ServerPubKey)) secret, err := ecdh.GenerateSharedSecret(ephPv, authInfo.ServerPubKey)
if err != nil {
log.Panicf("error in generating shared secret: %v", err)
}
copy(sharedSecret[:], secret)
ciphertextWithTag, _ := common.AESGCMEncrypt(ret.randPubKey[:12], sharedSecret[:], plaintext) ciphertextWithTag, _ := common.AESGCMEncrypt(ret.randPubKey[:12], sharedSecret[:], plaintext)
copy(ret.ciphertextWithTag[:], ciphertextWithTag[:]) copy(ret.ciphertextWithTag[:], ciphertextWithTag[:])
return return

View File

@ -68,13 +68,11 @@ func Unmarshal(data []byte) (crypto.PublicKey, bool) {
return &pub, true return &pub, true
} }
func GenerateSharedSecret(privKey crypto.PrivateKey, pubKey crypto.PublicKey) []byte { func GenerateSharedSecret(privKey crypto.PrivateKey, pubKey crypto.PublicKey) ([]byte, error) {
var priv, pub, secret *[32]byte var priv, pub *[32]byte
priv = privKey.(*[32]byte) priv = privKey.(*[32]byte)
pub = pubKey.(*[32]byte) pub = pubKey.(*[32]byte)
secret = new([32]byte)
curve25519.ScalarMult(secret, priv, pub) return curve25519.X25519(priv[:], pub[:])
return secret[:]
} }

View File

@ -90,11 +90,11 @@ func testECDH(t testing.TB) {
t.Fatalf("Unmarshal does not work") t.Fatalf("Unmarshal does not work")
} }
secret1 = GenerateSharedSecret(privKey1, pubKey2) secret1, err = GenerateSharedSecret(privKey1, pubKey2)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
secret2 = GenerateSharedSecret(privKey2, pubKey1) secret2, err = GenerateSharedSecret(privKey2, pubKey1)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }

View File

@ -112,7 +112,7 @@ func (d *datagramBufferedPipe) WriteTo(w io.Writer) (n int64, err error) {
} }
} }
func (d *datagramBufferedPipe) Write(f Frame) (toBeClosed bool, err error) { func (d *datagramBufferedPipe) Write(f *Frame) (toBeClosed bool, err error) {
d.rwCond.L.Lock() d.rwCond.L.Lock()
defer d.rwCond.L.Unlock() defer d.rwCond.L.Unlock()
if d.buf == nil { if d.buf == nil {

View File

@ -10,7 +10,7 @@ func TestDatagramBuffer_RW(t *testing.T) {
b := []byte{0x01, 0x02, 0x03} b := []byte{0x01, 0x02, 0x03}
t.Run("simple write", func(t *testing.T) { t.Run("simple write", func(t *testing.T) {
pipe := NewDatagramBufferedPipe() pipe := NewDatagramBufferedPipe()
_, err := pipe.Write(Frame{Payload: b}) _, err := pipe.Write(&Frame{Payload: b})
if err != nil { if err != nil {
t.Error( t.Error(
"expecting", "nil error", "expecting", "nil error",
@ -22,7 +22,7 @@ func TestDatagramBuffer_RW(t *testing.T) {
t.Run("simple read", func(t *testing.T) { t.Run("simple read", func(t *testing.T) {
pipe := NewDatagramBufferedPipe() pipe := NewDatagramBufferedPipe()
_, _ = pipe.Write(Frame{Payload: b}) _, _ = pipe.Write(&Frame{Payload: b})
b2 := make([]byte, len(b)) b2 := make([]byte, len(b))
n, err := pipe.Read(b2) n, err := pipe.Read(b2)
if n != len(b) { if n != len(b) {
@ -55,7 +55,7 @@ func TestDatagramBuffer_RW(t *testing.T) {
t.Run("writing closing frame", func(t *testing.T) { t.Run("writing closing frame", func(t *testing.T) {
pipe := NewDatagramBufferedPipe() pipe := NewDatagramBufferedPipe()
toBeClosed, err := pipe.Write(Frame{Closing: closingStream}) toBeClosed, err := pipe.Write(&Frame{Closing: closingStream})
if !toBeClosed { if !toBeClosed {
t.Error("should be to be closed") t.Error("should be to be closed")
} }
@ -77,7 +77,7 @@ func TestDatagramBuffer_BlockingRead(t *testing.T) {
b := []byte{0x01, 0x02, 0x03} b := []byte{0x01, 0x02, 0x03}
go func() { go func() {
time.Sleep(readBlockTime) time.Sleep(readBlockTime)
pipe.Write(Frame{Payload: b}) pipe.Write(&Frame{Payload: b})
}() }()
b2 := make([]byte, len(b)) b2 := make([]byte, len(b))
n, err := pipe.Read(b2) n, err := pipe.Read(b2)
@ -110,7 +110,7 @@ func TestDatagramBuffer_BlockingRead(t *testing.T) {
func TestDatagramBuffer_CloseThenRead(t *testing.T) { func TestDatagramBuffer_CloseThenRead(t *testing.T) {
pipe := NewDatagramBufferedPipe() pipe := NewDatagramBufferedPipe()
b := []byte{0x01, 0x02, 0x03} b := []byte{0x01, 0x02, 0x03}
pipe.Write(Frame{Payload: b}) pipe.Write(&Frame{Payload: b})
b2 := make([]byte, len(b)) b2 := make([]byte, len(b))
pipe.Close() pipe.Close()
n, err := pipe.Read(b2) n, err := pipe.Read(b2)

View File

@ -10,7 +10,6 @@ import (
"net" "net"
"sync" "sync"
"testing" "testing"
"time"
) )
func serveEcho(l net.Listener) { func serveEcho(l net.Listener) {
@ -64,21 +63,20 @@ func makeSessionPair(numConn int) (*Session, *Session, []*connPair) {
return clientSession, serverSession, paris return clientSession, serverSession, paris
} }
func runEchoTest(t *testing.T, conns []net.Conn, maxMsgLen int) { func runEchoTest(t *testing.T, conns []net.Conn, msgLen int) {
var wg sync.WaitGroup var wg sync.WaitGroup
for _, conn := range conns { for _, conn := range conns {
wg.Add(1) wg.Add(1)
go func(conn net.Conn) { go func(conn net.Conn) {
testDataLen := rand.Intn(maxMsgLen) testData := make([]byte, msgLen)
testData := make([]byte, testDataLen)
rand.Read(testData) rand.Read(testData)
n, err := conn.Write(testData) n, err := conn.Write(testData)
if n != testDataLen { if n != msgLen {
t.Fatalf("written only %v, err %v", n, err) t.Fatalf("written only %v, err %v", n, err)
} }
recvBuf := make([]byte, testDataLen) recvBuf := make([]byte, msgLen)
_, err = io.ReadFull(conn, recvBuf) _, err = io.ReadFull(conn, recvBuf)
if err != nil { if err != nil {
t.Fatalf("failed to read back: %v", err) t.Fatalf("failed to read back: %v", err)
@ -96,7 +94,7 @@ func runEchoTest(t *testing.T, conns []net.Conn, maxMsgLen int) {
func TestMultiplex(t *testing.T) { func TestMultiplex(t *testing.T) {
const numStreams = 2000 // -race option limits the number of goroutines to 8192 const numStreams = 2000 // -race option limits the number of goroutines to 8192
const numConns = 4 const numConns = 4
const maxMsgLen = 16384 const msgLen = 16384
clientSession, serverSession, _ := makeSessionPair(numConns) clientSession, serverSession, _ := makeSessionPair(numConns)
go serveEcho(serverSession) go serveEcho(serverSession)
@ -111,15 +109,10 @@ func TestMultiplex(t *testing.T) {
} }
//test echo //test echo
runEchoTest(t, streams, maxMsgLen) runEchoTest(t, streams, msgLen)
assert.Eventuallyf(t, func() bool { assert.EqualValues(t, numStreams, clientSession.streamCount(), "client stream count is wrong")
return clientSession.streamCount() == numStreams assert.EqualValues(t, numStreams, serverSession.streamCount(), "server stream count is wrong")
}, time.Second, 10*time.Millisecond, "client stream count is wrong: %v", clientSession.streamCount())
assert.Eventuallyf(t, func() bool {
return serverSession.streamCount() == numStreams
}, time.Second, 10*time.Millisecond, "server stream count is wrong: %v", serverSession.streamCount())
// close one stream // close one stream
closing, streams := streams[0], streams[1:] closing, streams := streams[0], streams[1:]

View File

@ -12,7 +12,7 @@ import (
) )
type Obfser func(*Frame, []byte, int) (int, error) type Obfser func(*Frame, []byte, int) (int, error)
type Deobfser func([]byte) (*Frame, error) type Deobfser func(*Frame, []byte) error
var u32 = binary.BigEndian.Uint32 var u32 = binary.BigEndian.Uint32
var u64 = binary.BigEndian.Uint64 var u64 = binary.BigEndian.Uint64
@ -135,9 +135,9 @@ func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Obfser {
func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser { func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser {
// frame header length + minimum data size (i.e. nonce size of salsa20) // frame header length + minimum data size (i.e. nonce size of salsa20)
const minInputLen = frameHeaderLength + salsa20NonceSize const minInputLen = frameHeaderLength + salsa20NonceSize
deobfs := func(in []byte) (*Frame, error) { deobfs := func(f *Frame, in []byte) error {
if len(in) < minInputLen { if len(in) < minInputLen {
return nil, fmt.Errorf("input size %v, but it cannot be shorter than %v bytes", len(in), minInputLen) return fmt.Errorf("input size %v, but it cannot be shorter than %v bytes", len(in), minInputLen)
} }
header := in[:frameHeaderLength] header := in[:frameHeaderLength]
@ -153,7 +153,7 @@ func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser {
usefulPayloadLen := len(pldWithOverHead) - int(extraLen) usefulPayloadLen := len(pldWithOverHead) - int(extraLen)
if usefulPayloadLen < 0 || usefulPayloadLen > len(pldWithOverHead) { if usefulPayloadLen < 0 || usefulPayloadLen > len(pldWithOverHead) {
return nil, errors.New("extra length is negative or extra length is greater than total pldWithOverHead length") return errors.New("extra length is negative or extra length is greater than total pldWithOverHead length")
} }
var outputPayload []byte var outputPayload []byte
@ -167,18 +167,16 @@ func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser {
} else { } else {
_, err := payloadCipher.Open(pldWithOverHead[:0], header[:payloadCipher.NonceSize()], pldWithOverHead, nil) _, err := payloadCipher.Open(pldWithOverHead[:0], header[:payloadCipher.NonceSize()], pldWithOverHead, nil)
if err != nil { if err != nil {
return nil, err return err
} }
outputPayload = pldWithOverHead[:usefulPayloadLen] outputPayload = pldWithOverHead[:usefulPayloadLen]
} }
ret := &Frame{ f.StreamID = streamID
StreamID: streamID, f.Seq = seq
Seq: seq, f.Closing = closing
Closing: closing, f.Payload = outputPayload
Payload: outputPayload, return nil
}
return ret, nil
} }
return deobfs return deobfs
} }

View File

@ -17,8 +17,7 @@ func TestGenerateObfs(t *testing.T) {
run := func(obfuscator Obfuscator, ct *testing.T) { run := func(obfuscator Obfuscator, ct *testing.T) {
obfsBuf := make([]byte, 512) obfsBuf := make([]byte, 512)
f := &Frame{} _testFrame, _ := quick.Value(reflect.TypeOf(&Frame{}), rand.New(rand.NewSource(42)))
_testFrame, _ := quick.Value(reflect.TypeOf(f), rand.New(rand.NewSource(42)))
testFrame := _testFrame.Interface().(*Frame) testFrame := _testFrame.Interface().(*Frame)
i, err := obfuscator.Obfs(testFrame, obfsBuf, 0) i, err := obfuscator.Obfs(testFrame, obfsBuf, 0)
if err != nil { if err != nil {
@ -26,7 +25,8 @@ func TestGenerateObfs(t *testing.T) {
return return
} }
resultFrame, err := obfuscator.Deobfs(obfsBuf[:i]) var resultFrame Frame
err = obfuscator.Deobfs(&resultFrame, obfsBuf[:i])
if err != nil { if err != nil {
ct.Error("failed to deobfs ", err) ct.Error("failed to deobfs ", err)
return return
@ -148,10 +148,11 @@ func BenchmarkDeobfs(b *testing.B) {
n, _ := obfs(testFrame, obfsBuf, 0) n, _ := obfs(testFrame, obfsBuf, 0)
deobfs := MakeDeobfs(key, payloadCipher) deobfs := MakeDeobfs(key, payloadCipher)
frame := new(Frame)
b.SetBytes(int64(n)) b.SetBytes(int64(n))
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
deobfs(obfsBuf[:n]) deobfs(frame, obfsBuf[:n])
} }
}) })
b.Run("AES128GCM", func(b *testing.B) { b.Run("AES128GCM", func(b *testing.B) {
@ -162,10 +163,11 @@ func BenchmarkDeobfs(b *testing.B) {
n, _ := obfs(testFrame, obfsBuf, 0) n, _ := obfs(testFrame, obfsBuf, 0)
deobfs := MakeDeobfs(key, payloadCipher) deobfs := MakeDeobfs(key, payloadCipher)
frame := new(Frame)
b.ResetTimer() b.ResetTimer()
b.SetBytes(int64(n)) b.SetBytes(int64(n))
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
deobfs(obfsBuf[:n]) deobfs(frame, obfsBuf[:n])
} }
}) })
b.Run("plain", func(b *testing.B) { b.Run("plain", func(b *testing.B) {
@ -173,10 +175,11 @@ func BenchmarkDeobfs(b *testing.B) {
n, _ := obfs(testFrame, obfsBuf, 0) n, _ := obfs(testFrame, obfsBuf, 0)
deobfs := MakeDeobfs(key, nil) deobfs := MakeDeobfs(key, nil)
frame := new(Frame)
b.ResetTimer() b.ResetTimer()
b.SetBytes(int64(n)) b.SetBytes(int64(n))
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
deobfs(obfsBuf[:n]) deobfs(frame, obfsBuf[:n])
} }
}) })
b.Run("chacha20Poly1305", func(b *testing.B) { b.Run("chacha20Poly1305", func(b *testing.B) {
@ -186,10 +189,11 @@ func BenchmarkDeobfs(b *testing.B) {
n, _ := obfs(testFrame, obfsBuf, 0) n, _ := obfs(testFrame, obfsBuf, 0)
deobfs := MakeDeobfs(key, payloadCipher) deobfs := MakeDeobfs(key, payloadCipher)
frame := new(Frame)
b.ResetTimer() b.ResetTimer()
b.SetBytes(int64(n)) b.SetBytes(int64(n))
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
deobfs(obfsBuf[:n]) deobfs(frame, obfsBuf[:n])
} }
}) })
} }

View File

@ -15,7 +15,7 @@ type recvBuffer interface {
// when the buffer is empty. // when the buffer is empty.
io.ReadCloser io.ReadCloser
io.WriterTo io.WriterTo
Write(Frame) (toBeClosed bool, err error) Write(*Frame) (toBeClosed bool, err error)
SetReadDeadline(time time.Time) SetReadDeadline(time time.Time)
// SetWriteToTimeout sets the duration a recvBuffer waits in a WriteTo call when nothing // SetWriteToTimeout sets the duration a recvBuffer waits in a WriteTo call when nothing
// has been written for a while. After that duration it should return ErrTimeout // has been written for a while. After that duration it should return ErrTimeout

View File

@ -63,7 +63,12 @@ type Session struct {
// atomic // atomic
activeStreamCount uint32 activeStreamCount uint32
streams sync.Map
streamsM sync.Mutex
streams map[uint32]*Stream
// a pool of heap allocated frame objects so we don't have to allocate a new one each time we receive a frame
recvFramePool sync.Pool
// Switchboard manages all connections to remote // Switchboard manages all connections to remote
sb *switchboard sb *switchboard
@ -89,6 +94,8 @@ func MakeSession(id uint32, config SessionConfig) *Session {
SessionConfig: config, SessionConfig: config,
nextStreamID: 1, nextStreamID: 1,
acceptCh: make(chan *Stream, acceptBacklog), acceptCh: make(chan *Stream, acceptBacklog),
recvFramePool: sync.Pool{New: func() interface{} { return &Frame{} }},
streams: map[uint32]*Stream{},
} }
sesh.addrs.Store([]net.Addr{nil, nil}) sesh.addrs.Store([]net.Addr{nil, nil})
@ -145,7 +152,9 @@ func (sesh *Session) OpenStream() (*Stream, error) {
return nil, errNoMultiplex return nil, errNoMultiplex
} }
stream := makeStream(sesh, id) stream := makeStream(sesh, id)
sesh.streams.Store(id, stream) sesh.streamsM.Lock()
sesh.streams[id] = stream
sesh.streamsM.Unlock()
sesh.streamCountIncr() sesh.streamCountIncr()
log.Tracef("stream %v of session %v opened", id, sesh.id) log.Tracef("stream %v of session %v opened", id, sesh.id)
return stream, nil return stream, nil
@ -165,24 +174,22 @@ func (sesh *Session) Accept() (net.Conn, error) {
} }
func (sesh *Session) closeStream(s *Stream, active bool) error { func (sesh *Session) closeStream(s *Stream, active bool) error {
// must be holding s.wirtingM on entry
if atomic.SwapUint32(&s.closed, 1) == 1 { if atomic.SwapUint32(&s.closed, 1) == 1 {
return fmt.Errorf("closing stream %v: %w", s.id, errRepeatStreamClosing) return fmt.Errorf("closing stream %v: %w", s.id, errRepeatStreamClosing)
} }
_ = s.recvBuf.Close() // recvBuf.Close should not return error _ = s.getRecvBuf().Close() // recvBuf.Close should not return error
if active { if active {
// Notify remote that this stream is closed // Notify remote that this stream is closed
padding := genRandomPadding() padding := genRandomPadding()
f := &Frame{ s.writingFrame.Closing = closingStream
StreamID: s.id, s.writingFrame.Payload = padding
Seq: s.nextSendSeq,
Closing: closingStream,
Payload: padding,
}
s.nextSendSeq++
obfsBuf := make([]byte, len(padding)+frameHeaderLength+sesh.Obfuscator.maxOverhead) obfsBuf := make([]byte, len(padding)+frameHeaderLength+sesh.Obfuscator.maxOverhead)
i, err := sesh.Obfs(f, obfsBuf, 0)
i, err := sesh.Obfs(&s.writingFrame, obfsBuf, 0)
s.writingFrame.Seq++
if err != nil { if err != nil {
return err return err
} }
@ -190,7 +197,7 @@ func (sesh *Session) closeStream(s *Stream, active bool) error {
if err != nil { if err != nil {
return err return err
} }
log.Tracef("stream %v actively closed. seq %v", s.id, f.Seq) log.Tracef("stream %v actively closed.", s.id)
} else { } else {
log.Tracef("stream %v passively closed", s.id) log.Tracef("stream %v passively closed", s.id)
} }
@ -198,7 +205,9 @@ func (sesh *Session) closeStream(s *Stream, active bool) error {
// We set it as nil to signify that the stream id had existed before. // We set it as nil to signify that the stream id had existed before.
// If we Delete(s.id) straight away, later on in recvDataFromRemote, it will not be able to tell // If we Delete(s.id) straight away, later on in recvDataFromRemote, it will not be able to tell
// if the frame it received was from a new stream or a dying stream whose frame arrived late // if the frame it received was from a new stream or a dying stream whose frame arrived late
sesh.streams.Store(s.id, nil) sesh.streamsM.Lock()
sesh.streams[s.id] = nil
sesh.streamsM.Unlock()
if sesh.streamCountDecr() == 0 { if sesh.streamCountDecr() == 0 {
if sesh.Singleplex { if sesh.Singleplex {
return sesh.Close() return sesh.Close()
@ -214,7 +223,10 @@ func (sesh *Session) closeStream(s *Stream, active bool) error {
// to the stream buffer, otherwise it fetches the desired stream instance, or creates and stores one if it's a new // to the stream buffer, otherwise it fetches the desired stream instance, or creates and stores one if it's a new
// stream and then writes to the stream buffer // stream and then writes to the stream buffer
func (sesh *Session) recvDataFromRemote(data []byte) error { func (sesh *Session) recvDataFromRemote(data []byte) error {
frame, err := sesh.Deobfs(data) frame := sesh.recvFramePool.Get().(*Frame)
defer sesh.recvFramePool.Put(frame)
err := sesh.Deobfs(frame, data)
if err != nil { if err != nil {
return fmt.Errorf("Failed to decrypt a frame for session %v: %v", sesh.id, err) return fmt.Errorf("Failed to decrypt a frame for session %v: %v", sesh.id, err)
} }
@ -224,19 +236,23 @@ func (sesh *Session) recvDataFromRemote(data []byte) error {
return sesh.passiveClose() return sesh.passiveClose()
} }
newStream := makeStream(sesh, frame.StreamID) sesh.streamsM.Lock()
existingStreamI, existing := sesh.streams.LoadOrStore(frame.StreamID, newStream) existingStream, existing := sesh.streams[frame.StreamID]
if existing { if existing {
if existingStreamI == nil { sesh.streamsM.Unlock()
if existingStream == nil {
// this is when the stream existed before but has since been closed. We do nothing // this is when the stream existed before but has since been closed. We do nothing
return nil return nil
} }
return existingStreamI.(*Stream).recvFrame(*frame) return existingStream.recvFrame(frame)
} else { } else {
newStream := makeStream(sesh, frame.StreamID)
sesh.streams[frame.StreamID] = newStream
sesh.streamsM.Unlock()
// new stream // new stream
sesh.streamCountIncr() sesh.streamCountIncr()
sesh.acceptCh <- newStream sesh.acceptCh <- newStream
return newStream.recvFrame(*frame) return newStream.recvFrame(frame)
} }
} }
@ -260,17 +276,17 @@ func (sesh *Session) closeSession(closeSwitchboard bool) error {
} }
sesh.acceptCh <- nil sesh.acceptCh <- nil
sesh.streams.Range(func(key, streamI interface{}) bool { sesh.streamsM.Lock()
if streamI == nil { for id, stream := range sesh.streams {
return true if stream == nil {
continue
} }
stream := streamI.(*Stream)
atomic.StoreUint32(&stream.closed, 1) atomic.StoreUint32(&stream.closed, 1)
_ = stream.recvBuf.Close() // will not block _ = stream.getRecvBuf().Close() // will not block
sesh.streams.Delete(key) delete(sesh.streams, id)
sesh.streamCountDecr() sesh.streamCountDecr()
return true }
}) sesh.streamsM.Unlock()
if closeSwitchboard { if closeSwitchboard {
sesh.sb.closeAll() sesh.sb.closeAll()

View File

@ -12,10 +12,9 @@ import (
"time" "time"
) )
var seshConfigOrdered = SessionConfig{} var seshConfigs = map[string]SessionConfig{
"ordered": {},
var seshConfigUnordered = SessionConfig{ "unordered": {Unordered: true},
Unordered: true,
} }
const testPayloadLen = 1024 const testPayloadLen = 1024
@ -43,40 +42,20 @@ func TestRecvDataFromRemote(t *testing.T) {
return ret return ret
} }
sessionTypes := []struct { encryptionMethods := map[string]Obfuscator{
name string "plain": MakeObfuscatorUnwrap(EncryptionMethodPlain, sessionKey),
config SessionConfig "aes-gcm": MakeObfuscatorUnwrap(EncryptionMethodAESGCM, sessionKey),
}{ "chacha20-poly1305": MakeObfuscatorUnwrap(EncryptionMethodChaha20Poly1305, sessionKey),
{"ordered",
SessionConfig{}},
{"unordered",
SessionConfig{Unordered: true}},
} }
encryptionMethods := []struct { for seshType, seshConfig := range seshConfigs {
name string seshConfig := seshConfig
obfuscator Obfuscator t.Run(seshType, func(t *testing.T) {
}{ for method, obfuscator := range encryptionMethods {
{ obfuscator := obfuscator
"plain", t.Run(method, func(t *testing.T) {
MakeObfuscatorUnwrap(EncryptionMethodPlain, sessionKey), seshConfig.Obfuscator = obfuscator
}, sesh := MakeSession(0, seshConfig)
{
"aes-gcm",
MakeObfuscatorUnwrap(EncryptionMethodAESGCM, sessionKey),
},
{
"chacha20-poly1305",
MakeObfuscatorUnwrap(EncryptionMethodChaha20Poly1305, sessionKey),
},
}
for _, st := range sessionTypes {
t.Run(st.name, func(t *testing.T) {
for _, em := range encryptionMethods {
t.Run(em.name, func(t *testing.T) {
st.config.Obfuscator = em.obfuscator
sesh := MakeSession(0, st.config)
n, err := sesh.Obfs(f, obfsBuf, 0) n, err := sesh.Obfs(f, obfsBuf, 0)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
@ -116,8 +95,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
var sessionKey [32]byte var sessionKey [32]byte
rand.Read(sessionKey[:]) rand.Read(sessionKey[:])
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey)
seshConfigOrdered.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfigOrdered) seshConfig := seshConfigs["ordered"]
seshConfig.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfig)
f1 := &Frame{ f1 := &Frame{
1, 1,
@ -131,7 +112,9 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("receiving normal frame for stream 1: %v", err) t.Fatalf("receiving normal frame for stream 1: %v", err)
} }
_, ok := sesh.streams.Load(f1.StreamID) sesh.streamsM.Lock()
_, ok := sesh.streams[f1.StreamID]
sesh.streamsM.Unlock()
if !ok { if !ok {
t.Fatal("failed to fetch stream 1 after receiving it") t.Fatal("failed to fetch stream 1 after receiving it")
} }
@ -151,8 +134,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("receiving normal frame for stream 2: %v", err) t.Fatalf("receiving normal frame for stream 2: %v", err)
} }
s2I, ok := sesh.streams.Load(f2.StreamID) sesh.streamsM.Lock()
if s2I == nil || !ok { s2M, ok := sesh.streams[f2.StreamID]
sesh.streamsM.Unlock()
if s2M == nil || !ok {
t.Fatal("failed to fetch stream 2 after receiving it") t.Fatal("failed to fetch stream 2 after receiving it")
} }
if sesh.streamCount() != 2 { if sesh.streamCount() != 2 {
@ -171,8 +156,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("receiving stream closing frame for stream 1: %v", err) t.Fatalf("receiving stream closing frame for stream 1: %v", err)
} }
s1I, _ := sesh.streams.Load(f1.StreamID) sesh.streamsM.Lock()
if s1I != nil { s1M, _ := sesh.streams[f1.StreamID]
sesh.streamsM.Unlock()
if s1M != nil {
t.Fatal("stream 1 still exist after receiving stream close") t.Fatal("stream 1 still exist after receiving stream close")
} }
s1, _ := sesh.Accept() s1, _ := sesh.Accept()
@ -198,8 +185,10 @@ func TestRecvDataFromRemote_Closing_InOrder(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("receiving stream closing frame for stream 1 %v", err) t.Fatalf("receiving stream closing frame for stream 1 %v", err)
} }
s1I, _ = sesh.streams.Load(f1.StreamID) sesh.streamsM.Lock()
if s1I != nil { s1M, _ = sesh.streams[f1.StreamID]
sesh.streamsM.Unlock()
if s1M != nil {
t.Error("stream 1 exists after receiving stream close for the second time") t.Error("stream 1 exists after receiving stream close for the second time")
} }
streamCount := sesh.streamCount() streamCount := sesh.streamCount()
@ -245,8 +234,10 @@ func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) {
var sessionKey [32]byte var sessionKey [32]byte
rand.Read(sessionKey[:]) rand.Read(sessionKey[:])
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey)
seshConfigOrdered.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfigOrdered) seshConfig := seshConfigs["ordered"]
seshConfig.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfig)
// receive stream 1 closing first // receive stream 1 closing first
f1CloseStream := &Frame{ f1CloseStream := &Frame{
@ -260,7 +251,9 @@ func TestRecvDataFromRemote_Closing_OutOfOrder(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("receiving out of order stream closing frame for stream 1: %v", err) t.Fatalf("receiving out of order stream closing frame for stream 1: %v", err)
} }
_, ok := sesh.streams.Load(f1CloseStream.StreamID) sesh.streamsM.Lock()
_, ok := sesh.streams[f1CloseStream.StreamID]
sesh.streamsM.Unlock()
if !ok { if !ok {
t.Fatal("stream 1 doesn't exist") t.Fatal("stream 1 doesn't exist")
} }
@ -300,8 +293,12 @@ func TestParallelStreams(t *testing.T) {
var sessionKey [32]byte var sessionKey [32]byte
rand.Read(sessionKey[:]) rand.Read(sessionKey[:])
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey)
seshConfigOrdered.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfigOrdered) for seshType, seshConfig := range seshConfigs {
seshConfig := seshConfig
t.Run(seshType, func(t *testing.T) {
seshConfig.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfig)
numStreams := acceptBacklog numStreams := acceptBacklog
seqs := make([]*uint64, numStreams) seqs := make([]*uint64, numStreams)
@ -347,24 +344,27 @@ func TestParallelStreams(t *testing.T) {
wg.Wait() wg.Wait()
sc := int(sesh.streamCount()) sc := int(sesh.streamCount())
var count int var count int
sesh.streams.Range(func(_, s interface{}) bool { sesh.streamsM.Lock()
for _, s := range sesh.streams {
if s != nil { if s != nil {
count++ count++
} }
return true }
}) sesh.streamsM.Unlock()
if sc != count { if sc != count {
t.Errorf("broken referential integrety: actual %v, reference count: %v", count, sc) t.Errorf("broken referential integrety: actual %v, reference count: %v", count, sc)
} }
})
}
} }
func TestStream_SetReadDeadline(t *testing.T) { func TestStream_SetReadDeadline(t *testing.T) {
var sessionKey [32]byte for seshType, seshConfig := range seshConfigs {
rand.Read(sessionKey[:]) seshConfig := seshConfig
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) t.Run(seshType, func(t *testing.T) {
seshConfigOrdered.Obfuscator = obfuscator sesh := MakeSession(0, seshConfig)
sesh.AddConnection(connutil.Discard())
testReadDeadline := func(sesh *Session) {
t.Run("read after deadline set", func(t *testing.T) { t.Run("read after deadline set", func(t *testing.T) {
stream, _ := sesh.OpenStream() stream, _ := sesh.OpenStream()
_ = stream.SetReadDeadline(time.Now().Add(-1 * time.Second)) _ = stream.SetReadDeadline(time.Now().Add(-1 * time.Second))
@ -392,27 +392,27 @@ func TestStream_SetReadDeadline(t *testing.T) {
t.Error("Read did not unblock after deadline has passed") t.Error("Read did not unblock after deadline has passed")
} }
}) })
})
} }
sesh := MakeSession(0, seshConfigOrdered)
sesh.AddConnection(connutil.Discard())
testReadDeadline(sesh)
sesh = MakeSession(0, seshConfigUnordered)
sesh.AddConnection(connutil.Discard())
testReadDeadline(sesh)
} }
func TestSession_timeoutAfter(t *testing.T) { func TestSession_timeoutAfter(t *testing.T) {
var sessionKey [32]byte var sessionKey [32]byte
rand.Read(sessionKey[:]) rand.Read(sessionKey[:])
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey)
seshConfigOrdered.Obfuscator = obfuscator
seshConfigOrdered.InactivityTimeout = 100 * time.Millisecond for seshType, seshConfig := range seshConfigs {
sesh := MakeSession(0, seshConfigOrdered) seshConfig := seshConfig
t.Run(seshType, func(t *testing.T) {
seshConfig.Obfuscator = obfuscator
seshConfig.InactivityTimeout = 100 * time.Millisecond
sesh := MakeSession(0, seshConfig)
assert.Eventually(t, func() bool { assert.Eventually(t, func() bool {
return sesh.IsClosed() return sesh.IsClosed()
}, 5*seshConfigOrdered.InactivityTimeout, seshConfigOrdered.InactivityTimeout, "session should have timed out") }, 5*seshConfig.InactivityTimeout, seshConfig.InactivityTimeout, "session should have timed out")
})
}
} }
func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) { func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) {
@ -424,47 +424,73 @@ func BenchmarkRecvDataFromRemote_Ordered(b *testing.B) {
0, 0,
testPayload, testPayload,
} }
obfsBuf := make([]byte, obfsBufLen)
var sessionKey [32]byte var sessionKey [32]byte
rand.Read(sessionKey[:]) rand.Read(sessionKey[:])
b.Run("plain", func(b *testing.B) { table := map[string]byte{
obfuscator, _ := MakeObfuscator(EncryptionMethodPlain, sessionKey) "plain": EncryptionMethodPlain,
seshConfigOrdered.Obfuscator = obfuscator "aes-gcm": EncryptionMethodAESGCM,
sesh := MakeSession(0, seshConfigOrdered) "chacha20poly1305": EncryptionMethodChaha20Poly1305,
}
const maxIter = 100_000 // run with -benchtime 100000x to avoid index out of bounds panic
for name, ep := range table {
ep := ep
b.Run(name, func(b *testing.B) {
seshConfig := seshConfigs["ordered"]
obfuscator, _ := MakeObfuscator(ep, sessionKey)
seshConfig.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfig)
binaryFrames := [maxIter][]byte{}
for i := 0; i < maxIter; i++ {
obfsBuf := make([]byte, obfsBufLen)
n, _ := sesh.Obfs(f, obfsBuf, 0) n, _ := sesh.Obfs(f, obfsBuf, 0)
binaryFrames[i] = obfsBuf[:n]
f.Seq++
}
b.SetBytes(int64(len(f.Payload))) b.SetBytes(int64(len(f.Payload)))
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
sesh.recvDataFromRemote(obfsBuf[:n]) sesh.recvDataFromRemote(binaryFrames[i])
}
})
}
}
func BenchmarkMultiStreamWrite(b *testing.B) {
var sessionKey [32]byte
rand.Read(sessionKey[:])
table := map[string]byte{
"plain": EncryptionMethodPlain,
"aes-gcm": EncryptionMethodAESGCM,
"chacha20poly1305": EncryptionMethodChaha20Poly1305,
}
testPayload := make([]byte, testPayloadLen)
for name, ep := range table {
b.Run(name, func(b *testing.B) {
for seshType, seshConfig := range seshConfigs {
seshConfig := seshConfig
b.Run(seshType, func(b *testing.B) {
obfuscator, _ := MakeObfuscator(ep, sessionKey)
seshConfig.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfig)
sesh.AddConnection(connutil.Discard())
b.ResetTimer()
b.SetBytes(testPayloadLen)
b.RunParallel(func(pb *testing.PB) {
stream, _ := sesh.OpenStream()
for pb.Next() {
stream.Write(testPayload)
}
})
})
} }
}) })
b.Run("aes-gcm", func(b *testing.B) {
obfuscator, _ := MakeObfuscator(EncryptionMethodAESGCM, sessionKey)
seshConfigOrdered.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfigOrdered)
n, _ := sesh.Obfs(f, obfsBuf, 0)
b.SetBytes(int64(len(f.Payload)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
sesh.recvDataFromRemote(obfsBuf[:n])
} }
})
b.Run("chacha20-poly1305", func(b *testing.B) {
obfuscator, _ := MakeObfuscator(EncryptionMethodChaha20Poly1305, sessionKey)
seshConfigOrdered.Obfuscator = obfuscator
sesh := MakeSession(0, seshConfigOrdered)
n, _ := sesh.Obfs(f, obfsBuf, 0)
b.SetBytes(int64(len(f.Payload)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
sesh.recvDataFromRemote(obfsBuf[:n])
}
})
} }

View File

@ -23,21 +23,20 @@ type Stream struct {
session *Session session *Session
allocIdempot sync.Once
// a buffer (implemented as an asynchronous buffered pipe) to put data we've received from recvFrame but hasn't // a buffer (implemented as an asynchronous buffered pipe) to put data we've received from recvFrame but hasn't
// been read by the consumer through Read or WriteTo // been read by the consumer through Read or WriteTo. Lazily allocated
recvBuf recvBuffer recvBuf recvBuffer
writingM sync.Mutex writingM sync.Mutex
nextSendSeq uint64 writingFrame Frame // we do the allocation here to save repeated allocations in Write and ReadFrom
// atomic // atomic
closed uint32 closed uint32
// lazy allocation for obfsBuf. This is desirable because obfsBuf is only used when data is sent from // obfuscation happens in obfsBuf. This buffer is lazily allocated as obfsBuf is only used when data is sent from
// the stream (through Write or ReadFrom). Some streams never send data so eager allocation will waste // the stream (through Write or ReadFrom). Some streams never send data so eager allocation will waste
// memory // memory
allocIdempot sync.Once
// obfuscation happens in this buffer
obfsBuf []byte obfsBuf []byte
// When we want order guarantee (i.e. session.Unordered is false), // When we want order guarantee (i.e. session.Unordered is false),
@ -52,17 +51,14 @@ type Stream struct {
} }
func makeStream(sesh *Session, id uint32) *Stream { func makeStream(sesh *Session, id uint32) *Stream {
var recvBuf recvBuffer
if sesh.Unordered {
recvBuf = NewDatagramBufferedPipe()
} else {
recvBuf = NewStreamBuffer()
}
stream := &Stream{ stream := &Stream{
id: id, id: id,
session: sesh, session: sesh,
recvBuf: recvBuf, writingFrame: Frame{
StreamID: id,
Seq: 0,
Closing: closingNothing,
},
} }
return stream return stream
@ -70,9 +66,20 @@ func makeStream(sesh *Session, id uint32) *Stream {
func (s *Stream) isClosed() bool { return atomic.LoadUint32(&s.closed) == 1 } func (s *Stream) isClosed() bool { return atomic.LoadUint32(&s.closed) == 1 }
func (s *Stream) getRecvBuf() recvBuffer {
s.allocIdempot.Do(func() {
if s.session.Unordered {
s.recvBuf = NewDatagramBufferedPipe()
} else {
s.recvBuf = NewStreamBuffer()
}
})
return s.recvBuf
}
// receive a readily deobfuscated Frame so its payload can later be Read // receive a readily deobfuscated Frame so its payload can later be Read
func (s *Stream) recvFrame(frame Frame) error { func (s *Stream) recvFrame(frame *Frame) error {
toBeClosed, err := s.recvBuf.Write(frame) toBeClosed, err := s.getRecvBuf().Write(frame)
if toBeClosed { if toBeClosed {
err = s.passiveClose() err = s.passiveClose()
if errors.Is(err, errRepeatStreamClosing) { if errors.Is(err, errRepeatStreamClosing) {
@ -91,7 +98,7 @@ func (s *Stream) Read(buf []byte) (n int, err error) {
return 0, nil return 0, nil
} }
n, err = s.recvBuf.Read(buf) n, err = s.getRecvBuf().Read(buf)
log.Tracef("%v read from stream %v with err %v", n, s.id, err) log.Tracef("%v read from stream %v with err %v", n, s.id, err)
if err == io.EOF { if err == io.EOF {
return n, ErrBrokenStream return n, ErrBrokenStream
@ -102,7 +109,7 @@ func (s *Stream) Read(buf []byte) (n int, err error) {
// WriteTo continuously write data Stream has received into the writer w. // WriteTo continuously write data Stream has received into the writer w.
func (s *Stream) WriteTo(w io.Writer) (int64, error) { func (s *Stream) WriteTo(w io.Writer) (int64, error) {
// will keep writing until the underlying buffer is closed // will keep writing until the underlying buffer is closed
n, err := s.recvBuf.WriteTo(w) n, err := s.getRecvBuf().WriteTo(w)
log.Tracef("%v read from stream %v with err %v", n, s.id, err) log.Tracef("%v read from stream %v with err %v", n, s.id, err)
if err == io.EOF { if err == io.EOF {
return n, ErrBrokenStream return n, ErrBrokenStream
@ -110,15 +117,14 @@ func (s *Stream) WriteTo(w io.Writer) (int64, error) {
return n, nil return n, nil
} }
func (s *Stream) obfuscateAndSend(f *Frame, payloadOffsetInObfsBuf int) error { func (s *Stream) obfuscateAndSend(payloadOffsetInObfsBuf int) error {
var cipherTextLen int var cipherTextLen int
cipherTextLen, err := s.session.Obfs(f, s.obfsBuf, payloadOffsetInObfsBuf) cipherTextLen, err := s.session.Obfs(&s.writingFrame, s.obfsBuf, payloadOffsetInObfsBuf)
if err != nil { if err != nil {
return err return err
} }
_, err = s.session.sb.send(s.obfsBuf[:cipherTextLen], &s.assignedConnId) _, err = s.session.sb.send(s.obfsBuf[:cipherTextLen], &s.assignedConnId)
log.Tracef("%v sent to remote through stream %v with err %v. seq: %v", len(f.Payload), s.id, err, f.Seq)
if err != nil { if err != nil {
if err == errBrokenSwitchboard { if err == errBrokenSwitchboard {
s.session.SetTerminalMsg(err.Error()) s.session.SetTerminalMsg(err.Error())
@ -154,14 +160,9 @@ func (s *Stream) Write(in []byte) (n int, err error) {
} }
framePayload = in[n : s.session.maxStreamUnitWrite+n] framePayload = in[n : s.session.maxStreamUnitWrite+n]
} }
f := &Frame{ s.writingFrame.Payload = framePayload
StreamID: s.id, err = s.obfuscateAndSend(0)
Seq: s.nextSendSeq, s.writingFrame.Seq++
Closing: closingNothing,
Payload: framePayload,
}
s.nextSendSeq++
err = s.obfuscateAndSend(f, 0)
if err != nil { if err != nil {
return return
} }
@ -193,14 +194,9 @@ func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) {
} }
s.writingM.Lock() s.writingM.Lock()
f := &Frame{ s.writingFrame.Payload = s.obfsBuf[frameHeaderLength : frameHeaderLength+read]
StreamID: s.id, err = s.obfuscateAndSend(frameHeaderLength)
Seq: s.nextSendSeq, s.writingFrame.Seq++
Closing: closingNothing,
Payload: s.obfsBuf[frameHeaderLength : frameHeaderLength+read],
}
s.nextSendSeq++
err = s.obfuscateAndSend(f, frameHeaderLength)
s.writingM.Unlock() s.writingM.Unlock()
if err != nil { if err != nil {
@ -225,8 +221,8 @@ func (s *Stream) Close() error {
func (s *Stream) LocalAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[0] } func (s *Stream) LocalAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[0] }
func (s *Stream) RemoteAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[1] } func (s *Stream) RemoteAddr() net.Addr { return s.session.addrs.Load().([]net.Addr)[1] }
func (s *Stream) SetWriteToTimeout(d time.Duration) { s.recvBuf.SetWriteToTimeout(d) } func (s *Stream) SetWriteToTimeout(d time.Duration) { s.getRecvBuf().SetWriteToTimeout(d) }
func (s *Stream) SetReadDeadline(t time.Time) error { s.recvBuf.SetReadDeadline(t); return nil } func (s *Stream) SetReadDeadline(t time.Time) error { s.getRecvBuf().SetReadDeadline(t); return nil }
func (s *Stream) SetReadFromTimeout(d time.Duration) { s.readFromTimeout = d } func (s *Stream) SetReadFromTimeout(d time.Duration) { s.readFromTimeout = d }
var errNotImplemented = errors.New("Not implemented") var errNotImplemented = errors.New("Not implemented")

View File

@ -63,7 +63,7 @@ func NewStreamBuffer() *streamBuffer {
return sb return sb
} }
func (sb *streamBuffer) Write(f Frame) (toBeClosed bool, err error) { func (sb *streamBuffer) Write(f *Frame) (toBeClosed bool, err error) {
sb.recvM.Lock() sb.recvM.Lock()
defer sb.recvM.Unlock() defer sb.recvM.Unlock()
// when there'fs no ooo packages in heap and we receive the next package in order // when there'fs no ooo packages in heap and we receive the next package in order
@ -81,10 +81,11 @@ func (sb *streamBuffer) Write(f Frame) (toBeClosed bool, err error) {
return false, fmt.Errorf("seq %v is smaller than nextRecvSeq %v", f.Seq, sb.nextRecvSeq) return false, fmt.Errorf("seq %v is smaller than nextRecvSeq %v", f.Seq, sb.nextRecvSeq)
} }
heap.Push(&sb.sh, &f) saved := *f
heap.Push(&sb.sh, &saved)
// Keep popping from the heap until empty or to the point that the wanted seq was not received // Keep popping from the heap until empty or to the point that the wanted seq was not received
for len(sb.sh) > 0 && sb.sh[0].Seq == sb.nextRecvSeq { for len(sb.sh) > 0 && sb.sh[0].Seq == sb.nextRecvSeq {
f = *heap.Pop(&sb.sh).(*Frame) f = heap.Pop(&sb.sh).(*Frame)
if f.Closing != closingNothing { if f.Closing != closingNothing {
return true, nil return true, nil
} else { } else {

View File

@ -20,11 +20,10 @@ func TestRecvNewFrame(t *testing.T) {
for _, n := range set { for _, n := range set {
bu64 := make([]byte, 8) bu64 := make([]byte, 8)
binary.BigEndian.PutUint64(bu64, n) binary.BigEndian.PutUint64(bu64, n)
frame := Frame{ sb.Write(&Frame{
Seq: n, Seq: n,
Payload: bu64, Payload: bu64,
} })
sb.Write(frame)
} }
var sortedResult []uint64 var sortedResult []uint64
@ -80,7 +79,7 @@ func TestStreamBuffer_RecvThenClose(t *testing.T) {
Closing: 0, Closing: 0,
Payload: testData, Payload: testData,
} }
sb.Write(testFrame) sb.Write(&testFrame)
sb.Close() sb.Close()
readBuf := make([]byte, testDataLen) readBuf := make([]byte, testDataLen)

View File

@ -151,19 +151,31 @@ func TestStream_Close(t *testing.T) {
t.Error("failed to accept stream", err) t.Error("failed to accept stream", err)
return return
} }
// we read something to wait for the test frame to reach our recvBuffer.
// if it's empty by the point we call stream.Close(), the incoming
// frame will be dropped
readBuf := make([]byte, len(testPayload))
_, err = io.ReadFull(stream, readBuf[:1])
if err != nil {
t.Errorf("can't read any data before active closing")
}
err = stream.Close() err = stream.Close()
if err != nil { if err != nil {
t.Error("failed to actively close stream", err) t.Error("failed to actively close stream", err)
return return
} }
if sI, _ := sesh.streams.Load(stream.(*Stream).id); sI != nil { sesh.streamsM.Lock()
if s, _ := sesh.streams[stream.(*Stream).id]; s != nil {
sesh.streamsM.Unlock()
t.Error("stream still exists") t.Error("stream still exists")
return return
} }
sesh.streamsM.Unlock()
readBuf := make([]byte, len(testPayload)) _, err = io.ReadFull(stream, readBuf[1:])
_, err = io.ReadFull(stream, readBuf)
if err != nil { if err != nil {
t.Errorf("can't read residual data %v", err) t.Errorf("can't read residual data %v", err)
} }
@ -233,8 +245,10 @@ func TestStream_Close(t *testing.T) {
} }
assert.Eventually(t, func() bool { assert.Eventually(t, func() bool {
sI, _ := sesh.streams.Load(stream.(*Stream).id) sesh.streamsM.Lock()
return sI == nil s, _ := sesh.streams[stream.(*Stream).id]
sesh.streamsM.Unlock()
return s == nil
}, time.Second, 10*time.Millisecond, "streams still exists") }, time.Second, 10*time.Millisecond, "streams still exists")
}) })

View File

@ -79,7 +79,13 @@ func (TLS) unmarshalClientHello(ch *ClientHello, staticPv crypto.PrivateKey) (fr
return return
} }
copy(fragments.sharedSecret[:], ecdh.GenerateSharedSecret(staticPv, ephPub)) var sharedSecret []byte
sharedSecret, err = ecdh.GenerateSharedSecret(staticPv, ephPub)
if err != nil {
return
}
copy(fragments.sharedSecret[:], sharedSecret)
var keyShare []byte var keyShare []byte
keyShare, err = parseKeyShare(ch.extensions[[2]byte{0x00, 0x33}]) keyShare, err = parseKeyShare(ch.extensions[[2]byte{0x00, 0x33}])
if err != nil { if err != nil {

View File

@ -143,10 +143,15 @@ func InitState(preParse RawConfig, worldState common.WorldState) (sta *State, er
err = errors.New("command & control mode not implemented") err = errors.New("command & control mode not implemented")
return return
} else { } else {
manager, err := usermanager.MakeLocalManager(preParse.DatabasePath, worldState) var manager usermanager.UserManager
if len(preParse.AdminUID) == 0 || preParse.DatabasePath == "" {
manager = &usermanager.Voidmanager{}
} else {
manager, err = usermanager.MakeLocalManager(preParse.DatabasePath, worldState)
if err != nil { if err != nil {
return sta, err return sta, err
} }
}
sta.Panel = MakeUserPanel(manager) sta.Panel = MakeUserPanel(manager)
} }

View File

@ -40,6 +40,7 @@ const (
var ErrUserNotFound = errors.New("UID does not correspond to a user") var ErrUserNotFound = errors.New("UID does not correspond to a user")
var ErrSessionsCapReached = errors.New("Sessions cap has reached") var ErrSessionsCapReached = errors.New("Sessions cap has reached")
var ErrMangerIsVoid = errors.New("cannot perform operation with user manager as database path is not specified")
var ErrNoUpCredit = errors.New("No upload credit left") var ErrNoUpCredit = errors.New("No upload credit left")
var ErrNoDownCredit = errors.New("No download credit left") var ErrNoDownCredit = errors.New("No download credit left")

View File

@ -0,0 +1,31 @@
package usermanager
type Voidmanager struct{}
func (v *Voidmanager) AuthenticateUser(bytes []byte) (int64, int64, error) {
return 0, 0, ErrMangerIsVoid
}
func (v *Voidmanager) AuthoriseNewSession(bytes []byte, info AuthorisationInfo) error {
return ErrMangerIsVoid
}
func (v *Voidmanager) UploadStatus(updates []StatusUpdate) ([]StatusResponse, error) {
return nil, ErrMangerIsVoid
}
func (v *Voidmanager) ListAllUsers() ([]UserInfo, error) {
return nil, ErrMangerIsVoid
}
func (v *Voidmanager) GetUserInfo(UID []byte) (UserInfo, error) {
return UserInfo{}, ErrMangerIsVoid
}
func (v *Voidmanager) WriteUserInfo(info UserInfo) error {
return ErrMangerIsVoid
}
func (v *Voidmanager) DeleteUser(UID []byte) error {
return ErrMangerIsVoid
}

View File

@ -0,0 +1,43 @@
package usermanager
import (
"github.com/stretchr/testify/assert"
"testing"
)
var v = &Voidmanager{}
func Test_Voidmanager_AuthenticateUser(t *testing.T) {
_, _, err := v.AuthenticateUser([]byte{})
assert.Equal(t, ErrMangerIsVoid, err)
}
func Test_Voidmanager_AuthoriseNewSession(t *testing.T) {
err := v.AuthoriseNewSession([]byte{}, AuthorisationInfo{})
assert.Equal(t, ErrMangerIsVoid, err)
}
func Test_Voidmanager_DeleteUser(t *testing.T) {
err := v.DeleteUser([]byte{})
assert.Equal(t, ErrMangerIsVoid, err)
}
func Test_Voidmanager_GetUserInfo(t *testing.T) {
_, err := v.GetUserInfo([]byte{})
assert.Equal(t, ErrMangerIsVoid, err)
}
func Test_Voidmanager_ListAllUsers(t *testing.T) {
_, err := v.ListAllUsers()
assert.Equal(t, ErrMangerIsVoid, err)
}
func Test_Voidmanager_UploadStatus(t *testing.T) {
_, err := v.UploadStatus([]StatusUpdate{})
assert.Equal(t, ErrMangerIsVoid, err)
}
func Test_Voidmanager_WriteUserInfo(t *testing.T) {
err := v.WriteUserInfo(UserInfo{})
assert.Equal(t, ErrMangerIsVoid, err)
}

View File

@ -185,6 +185,9 @@ func (panel *userPanel) commitUpdate() error {
panel.usageUpdateQueue = make(map[[16]byte]*usagePair) panel.usageUpdateQueue = make(map[[16]byte]*usagePair)
panel.usageUpdateQueueM.Unlock() panel.usageUpdateQueueM.Unlock()
if len(statuses) == 0 {
return nil
}
responses, err := panel.Manager.UploadStatus(statuses) responses, err := panel.Manager.UploadStatus(statuses)
if err != nil { if err != nil {
return err return err

View File

@ -84,7 +84,13 @@ func (WebSocket) unmarshalHidden(hidden []byte, staticPv crypto.PrivateKey) (fra
return return
} }
copy(fragments.sharedSecret[:], ecdh.GenerateSharedSecret(staticPv, ephPub)) var sharedSecret []byte
sharedSecret, err = ecdh.GenerateSharedSecret(staticPv, ephPub)
if err != nil {
return
}
copy(fragments.sharedSecret[:], sharedSecret)
if len(hidden[32:]) != 64 { if len(hidden[32:]) != 64 {
err = fmt.Errorf("%v: %v", ErrCiphertextLength, len(hidden[32:])) err = fmt.Errorf("%v: %v", ErrCiphertextLength, len(hidden[32:]))

View File

@ -12,10 +12,8 @@ import (
"github.com/cbeuw/connutil" "github.com/cbeuw/connutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"io" "io"
"io/ioutil"
"math/rand" "math/rand"
"net" "net"
"os"
"sync" "sync"
"testing" "testing"
"time" "time"
@ -24,8 +22,6 @@ import (
) )
const numConns = 200 // -race option limits the number of goroutines to 8192 const numConns = 200 // -race option limits the number of goroutines to 8192
const delayBeforeTestingConnClose = 500 * time.Millisecond
const connCloseRetries = 3
func serveTCPEcho(l net.Listener) { func serveTCPEcho(l net.Listener) {
for { for {
@ -137,15 +133,13 @@ func generateClientConfigs(rawConfig client.RawConfig, state common.WorldState)
return lcl, rmt, auth return lcl, rmt, auth
} }
func basicServerState(ws common.WorldState, db *os.File) *server.State { func basicServerState(ws common.WorldState) *server.State {
var serverConfig = server.RawConfig{ var serverConfig = server.RawConfig{
ProxyBook: map[string][]string{"shadowsocks": {"tcp", "fake.com:9999"}, "openvpn": {"udp", "fake.com:9999"}}, ProxyBook: map[string][]string{"shadowsocks": {"tcp", "fake.com:9999"}, "openvpn": {"udp", "fake.com:9999"}},
BindAddr: []string{"fake.com:9999"}, BindAddr: []string{"fake.com:9999"},
BypassUID: [][]byte{bypassUID[:]}, BypassUID: [][]byte{bypassUID[:]},
RedirAddr: "fake.com:9999", RedirAddr: "fake.com:9999",
PrivateKey: privateKey, PrivateKey: privateKey,
AdminUID: nil,
DatabasePath: db.Name(),
KeepAlive: 15, KeepAlive: 15,
CncMode: false, CncMode: false,
} }
@ -258,13 +252,11 @@ func runEchoTest(t *testing.T, conns []net.Conn, maxMsgLen int) {
} }
func TestUDP(t *testing.T) { func TestUDP(t *testing.T) {
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
defer os.Remove(tmpDB.Name())
log.SetLevel(log.ErrorLevel) log.SetLevel(log.ErrorLevel)
worldState := common.WorldOfTime(time.Unix(10, 0)) worldState := common.WorldOfTime(time.Unix(10, 0))
lcc, rcc, ai := generateClientConfigs(basicUDPConfig, worldState) lcc, rcc, ai := generateClientConfigs(basicUDPConfig, worldState)
sta := basicServerState(worldState, tmpDB) sta := basicServerState(worldState)
proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta) proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta)
if err != nil { if err != nil {
@ -319,9 +311,7 @@ func TestTCPSingleplex(t *testing.T) {
log.SetLevel(log.ErrorLevel) log.SetLevel(log.ErrorLevel)
worldState := common.WorldOfTime(time.Unix(10, 0)) worldState := common.WorldOfTime(time.Unix(10, 0))
lcc, rcc, ai := generateClientConfigs(singleplexTCPConfig, worldState) lcc, rcc, ai := generateClientConfigs(singleplexTCPConfig, worldState)
var tmpDB, _ = ioutil.TempFile("", "ck_user_info") sta := basicServerState(worldState)
defer os.Remove(tmpDB.Name())
sta := basicServerState(worldState, tmpDB)
proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta) proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -381,9 +371,7 @@ func TestTCPMultiplex(t *testing.T) {
worldState := common.WorldOfTime(time.Unix(10, 0)) worldState := common.WorldOfTime(time.Unix(10, 0))
lcc, rcc, ai := generateClientConfigs(basicTCPConfig, worldState) lcc, rcc, ai := generateClientConfigs(basicTCPConfig, worldState)
var tmpDB, _ = ioutil.TempFile("", "ck_user_info") sta := basicServerState(worldState)
defer os.Remove(tmpDB.Name())
sta := basicServerState(worldState, tmpDB)
proxyToCkClientD, proxyFromCkServerL, netToCkServerD, redirFromCkServerL, err := establishSession(lcc, rcc, ai, sta) proxyToCkClientD, proxyFromCkServerL, netToCkServerD, redirFromCkServerL, err := establishSession(lcc, rcc, ai, sta)
if err != nil { if err != nil {
@ -456,11 +444,8 @@ func TestClosingStreamsFromProxy(t *testing.T) {
clientConfig := clientConfig clientConfig := clientConfig
clientConfigName := clientConfigName clientConfigName := clientConfigName
t.Run(clientConfigName, func(t *testing.T) { t.Run(clientConfigName, func(t *testing.T) {
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
defer os.Remove(tmpDB.Name())
lcc, rcc, ai := generateClientConfigs(clientConfig, worldState) lcc, rcc, ai := generateClientConfigs(clientConfig, worldState)
sta := basicServerState(worldState, tmpDB) sta := basicServerState(worldState)
proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta) proxyToCkClientD, proxyFromCkServerL, _, _, err := establishSession(lcc, rcc, ai, sta)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -519,12 +504,10 @@ func TestClosingStreamsFromProxy(t *testing.T) {
} }
func BenchmarkThroughput(b *testing.B) { func BenchmarkThroughput(b *testing.B) {
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
defer os.Remove(tmpDB.Name())
log.SetLevel(log.ErrorLevel) log.SetLevel(log.ErrorLevel)
worldState := common.WorldOfTime(time.Unix(10, 0)) worldState := common.WorldOfTime(time.Unix(10, 0))
lcc, rcc, ai := generateClientConfigs(basicTCPConfig, worldState) lcc, rcc, ai := generateClientConfigs(basicTCPConfig, worldState)
sta := basicServerState(worldState, tmpDB) sta := basicServerState(worldState)
const bufSize = 16 * 1024 const bufSize = 16 * 1024
encryptionMethods := map[string]byte{ encryptionMethods := map[string]byte{

View File

@ -1,3 +1,5 @@
#!/usr/bin/env bash
go get github.com/mitchellh/gox go get github.com/mitchellh/gox
mkdir -p release mkdir -p release