diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index c146b26..23be6fa 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -6,6 +6,8 @@ import ( "sync" "sync/atomic" "time" + + log "github.com/sirupsen/logrus" ) const ( @@ -105,27 +107,30 @@ func (sesh *Session) delStream(id uint32) { sesh.streamsM.Unlock() } -// either fetch an existing stream or instantiate a new stream and put it in the dict, and return it -func (sesh *Session) getStream(id uint32, closingFrame bool) *Stream { - // it would have been neater to use defer Unlock(), however it gives - // non-negligable overhead and this function is performance critical +func (sesh *Session) recvDataFromRemote(data []byte) { + frame, err := sesh.Deobfs(data) + if err != nil { + log.Debugf("Failed to decrypt a frame for session %v: %v", sesh.id, err) + } + sesh.streamsM.Lock() defer sesh.streamsM.Unlock() - stream := sesh.streams[id] - if stream != nil { - return stream + stream, existing := sesh.streams[frame.StreamID] + if existing { + stream.writeFrame(frame) } else { - if closingFrame { - // If the stream has been closed and the current frame is a closing frame, - // we return nil - return nil + if frame.Closing == 1 { + // If the stream has been closed and the current frame is a closing frame, we do noop + return } else { - stream = makeStream(id, sesh) - sesh.streams[id] = stream + stream = makeStream(frame.StreamID, sesh) + sesh.streams[frame.StreamID] = stream sesh.acceptCh <- stream - return stream + stream.writeFrame(frame) + return } } + } func (sesh *Session) SetTerminalMsg(msg string) { @@ -156,6 +161,7 @@ func (sesh *Session) Close() error { sesh.streamsM.Unlock() sesh.sb.closeAll() + log.Debugf("session %v closed gracefully", sesh.id) return nil } diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index cbc5b7b..c8b4611 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -51,6 +51,8 @@ func makeStream(id uint32, sesh *Session) *Stream { func (s *Stream) isClosed() bool { return atomic.LoadUint32(&s.closed) == 1 } +func (s *Stream) writeFrame(frame *Frame) { s.sorter.writeNewFrame(frame) } + func (s *Stream) Read(buf []byte) (n int, err error) { if len(buf) == 0 { if s.isClosed() { @@ -94,7 +96,7 @@ func (s *Stream) Write(in []byte) (n int, err error) { if err != nil { return i, err } - n, err = s.session.sb.send(s.obfsBuf[:i]) + n, err = s.session.sb.Write(s.obfsBuf[:i]) return } @@ -137,7 +139,7 @@ func (s *Stream) Close() error { if err != nil { return err } - _, err = s.session.sb.send(s.obfsBuf[:i]) + _, err = s.session.sb.Write(s.obfsBuf[:i]) if err != nil { return err } diff --git a/internal/multiplex/switchboard.go b/internal/multiplex/switchboard.go index a0788b7..79b03d2 100644 --- a/internal/multiplex/switchboard.go +++ b/internal/multiplex/switchboard.go @@ -18,6 +18,8 @@ type switchboard struct { optimum atomic.Value // *connEnclave cesM sync.RWMutex ces []*connEnclave + + broken uint32 } func (sb *switchboard) getOptimum() *connEnclave { @@ -48,9 +50,13 @@ func makeSwitchboard(sesh *Session, valve *Valve) *switchboard { return sb } -var errNilOptimum error = errors.New("The optimal connection is nil") +var errNilOptimum = errors.New("The optimal connection is nil") +var errBrokenSwitchboard = errors.New("the switchboard is broken") -func (sb *switchboard) send(data []byte) (int, error) { +func (sb *switchboard) Write(data []byte) (int, error) { + if atomic.LoadUint32(&sb.broken) == 1 { + return 0, errBrokenSwitchboard + } ce := sb.getOptimum() if ce == nil { return 0, errNilOptimum @@ -104,17 +110,20 @@ func (sb *switchboard) removeConn(closing *connEnclave) { break } } - if len(sb.ces) == 0 { - sb.session.SetTerminalMsg("no underlying connection left") - sb.cesM.Unlock() - sb.session.Close() - return - } + remaining := len(sb.ces) sb.cesM.Unlock() + if remaining == 0 { + atomic.StoreUint32(&sb.broken, 1) + sb.session.SetTerminalMsg("no underlying connection left") + sb.session.Close() + } } // actively triggered by session.Close() func (sb *switchboard) closeAll() { + if atomic.SwapUint32(&sb.broken, 1) == 1 { + return + } sb.cesM.RLock() for _, ce := range sb.ces { ce.remoteConn.Close() @@ -122,8 +131,7 @@ func (sb *switchboard) closeAll() { sb.cesM.RUnlock() } -// deplex function costantly reads from a TCP connection, call Deobfs and distribute it -// to the corresponding stream +// deplex function costantly reads from a TCP connection func (sb *switchboard) deplex(ce *connEnclave) { buf := make([]byte, 20480) for { @@ -137,18 +145,6 @@ func (sb *switchboard) deplex(ce *connEnclave) { return } - frame, err := sb.session.Deobfs(buf[:n]) - if err != nil { - log.Debugf("Failed to decrypt a frame for session %v: %v", sb.session.id, err) - continue - } - - stream := sb.session.getStream(frame.StreamID, frame.Closing == 1) - // if the frame is telling us to close a closed stream - // (this happens when ss-server and ss-local closes the stream - // simutaneously), we don't do anything - if stream != nil { - stream.sorter.writeNewFrame(frame) - } + sb.session.recvDataFromRemote(buf[:n]) } }