Fix a race, some cleanup

This commit is contained in:
Qian Wang 2018-11-24 00:55:26 +00:00
parent 3b656c9360
commit 239647c5b2
6 changed files with 50 additions and 26 deletions

View File

@ -22,6 +22,7 @@ import (
var version string var version string
func pipe(dst io.ReadWriteCloser, src io.ReadWriteCloser) { func pipe(dst io.ReadWriteCloser, src io.ReadWriteCloser) {
// TODO: auto reconnect
// The maximum size of TLS message will be 16396+12. 12 because of the stream header // The maximum size of TLS message will be 16396+12. 12 because of the stream header
// 16408 is the max TLS message size on Firefox // 16408 is the max TLS message size on Firefox
buf := make([]byte, 16396) buf := make([]byte, 16396)
@ -175,7 +176,6 @@ func main() {
sesh := mux.MakeSession(0, valve, obfs, deobfs, util.ReadTLS) sesh := mux.MakeSession(0, valve, obfs, deobfs, util.ReadTLS)
var wg sync.WaitGroup var wg sync.WaitGroup
// TODO: use sync group
for i := 0; i < sta.NumConn; i++ { for i := 0; i < sta.NumConn; i++ {
wg.Add(1) wg.Add(1)
go func() { go func() {
@ -190,6 +190,7 @@ func main() {
} }
wg.Wait() wg.Wait()
// TODO: ipv6
listener, err := net.Listen("tcp", sta.SS_LOCAL_HOST+":"+sta.SS_LOCAL_PORT) listener, err := net.Listen("tcp", sta.SS_LOCAL_HOST+":"+sta.SS_LOCAL_PORT)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)

View File

@ -1,4 +1,5 @@
{ {
"WebServerAddr":"204.79.197.200:443", "WebServerAddr":"204.79.197.200:443",
"Key":"UGUmcEmxWf0pKxfkZ/8EoP35Ht+wQnqf3L0xYgyQFlQ=" "Key":"UGUmcEmxWf0pKxfkZ/8EoP35Ht+wQnqf3L0xYgyQFlQ=",
"AdminUID":"ugDmcEmxWf0pKxfkZ/8EoP35Ht+wQnqf3L0xYgyQFlQ="
} }

View File

@ -56,7 +56,7 @@ func InitState(localHost, localPort, remoteHost, remotePort string, nowFunc func
// semi-colon separated value. This is for Android plugin options // semi-colon separated value. This is for Android plugin options
func ssvToJson(ssv string) (ret []byte) { func ssvToJson(ssv string) (ret []byte) {
// TODO: base64 encoded data has =. How to escape? // FIXME: base64 encoded data has =. How to escape?
unescape := func(s string) string { unescape := func(s string) string {
r := strings.Replace(s, "\\\\", "\\", -1) r := strings.Replace(s, "\\\\", "\\", -1)
r = strings.Replace(r, "\\=", "=", -1) r = strings.Replace(r, "\\=", "=", -1)

View File

@ -101,6 +101,31 @@ func (sesh *Session) isStream(id uint32) bool {
return ok return ok
} }
// If the stream has been closed and the triggering frame is a closing frame,
// we return nil
func (sesh *Session) getOrAddStream(id uint32, closingFrame bool) *Stream {
// it would have been neater to use defer Unlock(), however it gives
// non-negligable overhead and this function is performance critical
sesh.streamsM.Lock()
stream := sesh.streams[id]
if stream != nil {
sesh.streamsM.Unlock()
return stream
} else {
if closingFrame {
sesh.streamsM.Unlock()
return nil
} else {
stream = makeStream(id, sesh)
sesh.streams[id] = stream
sesh.acceptCh <- stream
log.Printf("Adding stream %v\n", id)
sesh.streamsM.Unlock()
return stream
}
}
}
func (sesh *Session) getStream(id uint32) *Stream { func (sesh *Session) getStream(id uint32) *Stream {
sesh.streamsM.RLock() sesh.streamsM.RLock()
ret := sesh.streams[id] ret := sesh.streams[id]

View File

@ -154,14 +154,12 @@ func (sb *switchboard) deplex(ce *connEnclave) {
} }
frame := sb.session.deobfs(buf[:n]) frame := sb.session.deobfs(buf[:n])
var stream *Stream stream := sb.session.getOrAddStream(frame.StreamID, frame.Closing == 1)
// FIXME: get-then-put without lock
if stream = sb.session.getStream(frame.StreamID); stream == nil {
if frame.Closing == 1 {
// if the frame is telling us to close a closed stream // if the frame is telling us to close a closed stream
// (this happens when ss-server and ss-local closes the stream // (this happens when ss-server and ss-local closes the stream
// simutaneously), we don't do anything // simutaneously), we don't do anything
continue if stream != nil {
stream.writeNewFrame(frame)
} }
//debug //debug
/* /*
@ -172,8 +170,5 @@ func (sb *switchboard) deplex(ce *connEnclave) {
sb.used[frame.StreamID] = true sb.used[frame.StreamID] = true
sb.hM.Unlock() sb.hM.Unlock()
*/ */
stream = sb.session.addStream(frame.StreamID)
}
stream.writeNewFrame(frame)
} }
} }

View File

@ -69,8 +69,9 @@ func (u *User) updateInfo(uinfo UserInfo) {
func (u *User) GetSession(sessionID uint32) *mux.Session { func (u *User) GetSession(sessionID uint32) *mux.Session {
u.sessionsM.RLock() u.sessionsM.RLock()
defer u.sessionsM.RUnlock() sesh := u.sessions[sessionID]
return u.sessions[sessionID] u.sessionsM.RUnlock()
return sesh
} }
func (u *User) PutSession(sessionID uint32, sesh *mux.Session) { func (u *User) PutSession(sessionID uint32, sesh *mux.Session) {
@ -93,13 +94,14 @@ func (u *User) DelSession(sessionID uint32) {
func (u *User) GetOrCreateSession(sessionID uint32, obfs func(*mux.Frame) []byte, deobfs func([]byte) *mux.Frame, obfsedRead func(net.Conn, []byte) (int, error)) (sesh *mux.Session, existing bool) { func (u *User) GetOrCreateSession(sessionID uint32, obfs func(*mux.Frame) []byte, deobfs func([]byte) *mux.Frame, obfsedRead func(net.Conn, []byte) (int, error)) (sesh *mux.Session, existing bool) {
// TODO: session cap // TODO: session cap
u.sessionsM.Lock() u.sessionsM.Lock()
defer u.sessionsM.Unlock()
if sesh = u.sessions[sessionID]; sesh != nil { if sesh = u.sessions[sessionID]; sesh != nil {
u.sessionsM.Unlock()
return sesh, true return sesh, true
} else { } else {
log.Printf("Creating session %v\n", sessionID) log.Printf("Creating session %v\n", sessionID)
sesh = mux.MakeSession(sessionID, u.valve, obfs, deobfs, obfsedRead) sesh = mux.MakeSession(sessionID, u.valve, obfs, deobfs, obfsedRead)
u.sessions[sessionID] = sesh u.sessions[sessionID] = sesh
u.sessionsM.Unlock()
return sesh, false return sesh, false
} }
} }