mirror of https://github.com/cbeuw/Cloak
Use sync.Pool for obfuscation buffer
This commit is contained in:
parent
5933ad8781
commit
881f6e6f9d
|
|
@ -70,6 +70,8 @@ type Session struct {
|
||||||
// a pool of heap allocated frame objects so we don't have to allocate a new one each time we receive a frame
|
// a pool of heap allocated frame objects so we don't have to allocate a new one each time we receive a frame
|
||||||
recvFramePool sync.Pool
|
recvFramePool sync.Pool
|
||||||
|
|
||||||
|
streamObfsBufPool sync.Pool
|
||||||
|
|
||||||
// Switchboard manages all connections to remote
|
// Switchboard manages all connections to remote
|
||||||
sb *switchboard
|
sb *switchboard
|
||||||
|
|
||||||
|
|
@ -117,6 +119,11 @@ func MakeSession(id uint32, config SessionConfig) *Session {
|
||||||
// todo: validation. this must be smaller than StreamSendBufferSize
|
// todo: validation. this must be smaller than StreamSendBufferSize
|
||||||
sesh.maxStreamUnitWrite = sesh.MsgOnWireSizeLimit - frameHeaderLength - sesh.Obfuscator.maxOverhead
|
sesh.maxStreamUnitWrite = sesh.MsgOnWireSizeLimit - frameHeaderLength - sesh.Obfuscator.maxOverhead
|
||||||
|
|
||||||
|
sesh.streamObfsBufPool = sync.Pool{New: func() interface{} {
|
||||||
|
b := make([]byte, sesh.StreamSendBufferSize)
|
||||||
|
return &b
|
||||||
|
}}
|
||||||
|
|
||||||
sesh.sb = makeSwitchboard(sesh)
|
sesh.sb = makeSwitchboard(sesh)
|
||||||
time.AfterFunc(sesh.InactivityTimeout, sesh.checkTimeout)
|
time.AfterFunc(sesh.InactivityTimeout, sesh.checkTimeout)
|
||||||
return sesh
|
return sesh
|
||||||
|
|
@ -180,25 +187,20 @@ func (sesh *Session) closeStream(s *Stream, active bool) error {
|
||||||
_ = s.getRecvBuf().Close() // recvBuf.Close should not return error
|
_ = s.getRecvBuf().Close() // recvBuf.Close should not return error
|
||||||
|
|
||||||
if active {
|
if active {
|
||||||
tmpBuf := make([]byte, 256+frameHeaderLength+sesh.Obfuscator.maxOverhead)
|
tmpBuf := sesh.streamObfsBufPool.Get().(*[]byte)
|
||||||
|
|
||||||
// Notify remote that this stream is closed
|
// Notify remote that this stream is closed
|
||||||
common.CryptoRandRead(tmpBuf[:1])
|
common.CryptoRandRead((*tmpBuf)[:1])
|
||||||
padLen := int(tmpBuf[0]) + 1
|
padLen := int((*tmpBuf)[0]) + 1
|
||||||
payload := tmpBuf[frameHeaderLength : padLen+frameHeaderLength]
|
payload := (*tmpBuf)[frameHeaderLength : padLen+frameHeaderLength]
|
||||||
common.CryptoRandRead(payload)
|
common.CryptoRandRead(payload)
|
||||||
|
|
||||||
// must be holding s.wirtingM on entry
|
// must be holding s.wirtingM on entry
|
||||||
s.writingFrame.Closing = closingStream
|
s.writingFrame.Closing = closingStream
|
||||||
s.writingFrame.Payload = payload
|
s.writingFrame.Payload = payload
|
||||||
|
|
||||||
cipherTextLen, err := sesh.Obfs(&s.writingFrame, tmpBuf, frameHeaderLength)
|
err := s.obfuscateAndSend(*tmpBuf, frameHeaderLength)
|
||||||
s.writingFrame.Seq++
|
sesh.streamObfsBufPool.Put(tmpBuf)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = sesh.sb.send(tmpBuf[:cipherTextLen], &s.assignedConnId)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -34,11 +34,6 @@ type Stream struct {
|
||||||
// atomic
|
// atomic
|
||||||
closed uint32
|
closed uint32
|
||||||
|
|
||||||
// obfuscation happens in obfsBuf. This buffer is lazily allocated as obfsBuf is only used when data is sent from
|
|
||||||
// the stream (through Write or ReadFrom). Some streams never send data so eager allocation will waste
|
|
||||||
// memory
|
|
||||||
obfsBuf []byte
|
|
||||||
|
|
||||||
// When we want order guarantee (i.e. session.Unordered is false),
|
// When we want order guarantee (i.e. session.Unordered is false),
|
||||||
// we assign each stream a fixed underlying connection.
|
// we assign each stream a fixed underlying connection.
|
||||||
// If the underlying connections the session uses provide ordering guarantee (most likely TCP),
|
// If the underlying connections the session uses provide ordering guarantee (most likely TCP),
|
||||||
|
|
@ -117,13 +112,14 @@ func (s *Stream) WriteTo(w io.Writer) (int64, error) {
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Stream) obfuscateAndSend(payloadOffsetInObfsBuf int) error {
|
func (s *Stream) obfuscateAndSend(buf []byte, payloadOffsetInBuf int) error {
|
||||||
cipherTextLen, err := s.session.Obfs(&s.writingFrame, s.obfsBuf, payloadOffsetInObfsBuf)
|
cipherTextLen, err := s.session.Obfs(&s.writingFrame, buf, payloadOffsetInBuf)
|
||||||
|
s.writingFrame.Seq++
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = s.session.sb.send(s.obfsBuf[:cipherTextLen], &s.assignedConnId)
|
_, err = s.session.sb.send(buf[:cipherTextLen], &s.assignedConnId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == errBrokenSwitchboard {
|
if err == errBrokenSwitchboard {
|
||||||
s.session.SetTerminalMsg(err.Error())
|
s.session.SetTerminalMsg(err.Error())
|
||||||
|
|
@ -142,9 +138,6 @@ func (s *Stream) Write(in []byte) (n int, err error) {
|
||||||
return 0, ErrBrokenStream
|
return 0, ErrBrokenStream
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.obfsBuf == nil {
|
|
||||||
s.obfsBuf = make([]byte, s.session.StreamSendBufferSize)
|
|
||||||
}
|
|
||||||
for n < len(in) {
|
for n < len(in) {
|
||||||
var framePayload []byte
|
var framePayload []byte
|
||||||
if len(in)-n <= s.session.maxStreamUnitWrite {
|
if len(in)-n <= s.session.maxStreamUnitWrite {
|
||||||
|
|
@ -160,8 +153,9 @@ func (s *Stream) Write(in []byte) (n int, err error) {
|
||||||
framePayload = in[n : s.session.maxStreamUnitWrite+n]
|
framePayload = in[n : s.session.maxStreamUnitWrite+n]
|
||||||
}
|
}
|
||||||
s.writingFrame.Payload = framePayload
|
s.writingFrame.Payload = framePayload
|
||||||
err = s.obfuscateAndSend(0)
|
buf := s.session.streamObfsBufPool.Get().(*[]byte)
|
||||||
s.writingFrame.Seq++
|
err = s.obfuscateAndSend(*buf, 0)
|
||||||
|
s.session.streamObfsBufPool.Put(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -173,9 +167,6 @@ func (s *Stream) Write(in []byte) (n int, err error) {
|
||||||
// ReadFrom continuously read data from r and send it off, until either r returns error or nothing has been read
|
// ReadFrom continuously read data from r and send it off, until either r returns error or nothing has been read
|
||||||
// for readFromTimeout amount of time
|
// for readFromTimeout amount of time
|
||||||
func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) {
|
func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) {
|
||||||
if s.obfsBuf == nil {
|
|
||||||
s.obfsBuf = make([]byte, s.session.StreamSendBufferSize)
|
|
||||||
}
|
|
||||||
for {
|
for {
|
||||||
if s.readFromTimeout != 0 {
|
if s.readFromTimeout != 0 {
|
||||||
if rder, ok := r.(net.Conn); !ok {
|
if rder, ok := r.(net.Conn); !ok {
|
||||||
|
|
@ -184,7 +175,8 @@ func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) {
|
||||||
rder.SetReadDeadline(time.Now().Add(s.readFromTimeout))
|
rder.SetReadDeadline(time.Now().Add(s.readFromTimeout))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
read, er := r.Read(s.obfsBuf[frameHeaderLength : frameHeaderLength+s.session.maxStreamUnitWrite])
|
buf := s.session.streamObfsBufPool.Get().(*[]byte)
|
||||||
|
read, er := r.Read((*buf)[frameHeaderLength : frameHeaderLength+s.session.maxStreamUnitWrite])
|
||||||
if er != nil {
|
if er != nil {
|
||||||
return n, er
|
return n, er
|
||||||
}
|
}
|
||||||
|
|
@ -196,10 +188,10 @@ func (s *Stream) ReadFrom(r io.Reader) (n int64, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
s.writingM.Lock()
|
s.writingM.Lock()
|
||||||
s.writingFrame.Payload = s.obfsBuf[frameHeaderLength : frameHeaderLength+read]
|
s.writingFrame.Payload = (*buf)[frameHeaderLength : frameHeaderLength+read]
|
||||||
err = s.obfuscateAndSend(frameHeaderLength)
|
err = s.obfuscateAndSend(*buf, frameHeaderLength)
|
||||||
s.writingFrame.Seq++
|
|
||||||
s.writingM.Unlock()
|
s.writingM.Unlock()
|
||||||
|
s.session.streamObfsBufPool.Put(buf)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue