Refactor session struct's obfs fields

This commit is contained in:
Qian Wang 2019-08-02 16:37:48 +01:00
parent 1a628cb524
commit e75c713385
9 changed files with 38 additions and 72 deletions

View File

@ -116,11 +116,11 @@ func makeSession(sta *client.State) *mux.Session {
wg.Wait() wg.Wait()
sessionKey := _sessionKey.Load().([]byte) sessionKey := _sessionKey.Load().([]byte)
obfs, deobfs, err := util.GenerateObfs(sta.EncryptionMethod, sessionKey) obfuscator, err := util.GenerateObfs(sta.EncryptionMethod, sessionKey)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
sesh := mux.MakeSession(sta.SessionID, mux.UNLIMITED_VALVE, obfs, deobfs, sessionKey, util.ReadTLS) sesh := mux.MakeSession(sta.SessionID, mux.UNLIMITED_VALVE, obfuscator, util.ReadTLS)
for i := 0; i < sta.NumConn; i++ { for i := 0; i < sta.NumConn; i++ {
conn := <-connsCh conn := <-connsCh

View File

@ -119,13 +119,13 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
sessionKey := make([]byte, 32) sessionKey := make([]byte, 32)
rand.Read(sessionKey) rand.Read(sessionKey)
obfs, deobfs, err := util.GenerateObfs(encryptionMethod, sessionKey) obfuscator, err := util.GenerateObfs(encryptionMethod, sessionKey)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
goWeb() goWeb()
} }
sesh, existing, err := user.GetSession(sessionID, obfs, deobfs, sessionKey, util.ReadTLS) sesh, existing, err := user.GetSession(sessionID, obfuscator, util.ReadTLS)
if err != nil { if err != nil {
user.DelSession(sessionID) user.DelSession(sessionID)
log.Error(err) log.Error(err)
@ -151,7 +151,7 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
log.Error(err) log.Error(err)
return return
} }
sesh := mux.MakeSession(0, mux.UNLIMITED_VALVE, obfs, deobfs, sessionKey, util.ReadTLS) sesh := mux.MakeSession(0, mux.UNLIMITED_VALVE, obfuscator, util.ReadTLS)
sesh.AddConnection(conn) sesh.AddConnection(conn)
//TODO: Router could be nil in cnc mode //TODO: Router could be nil in cnc mode
err = http.Serve(sesh, sta.LocalAPIRouter) err = http.Serve(sesh, sta.LocalAPIRouter)

View File

@ -73,7 +73,7 @@ func ssvToJson(ssv string) (ret []byte) {
value := sp[1] value := sp[1]
// JSON doesn't like quotation marks around int // JSON doesn't like quotation marks around int
// Yes this is extremely ugly but it's still better than writing a tokeniser // Yes this is extremely ugly but it's still better than writing a tokeniser
if key == "TicketTimeHint" || key == "NumConn" { if key == "NumConn" {
ret = append(ret, []byte(`"`+key+`":`+value+`,`)...) ret = append(ret, []byte(`"`+key+`":`+value+`,`)...)
} else { } else {
ret = append(ret, []byte(`"`+key+`":"`+value+`",`)...) ret = append(ret, []byte(`"`+key+`":"`+value+`",`)...)

View File

@ -15,17 +15,21 @@ const (
var ErrBrokenSession = errors.New("broken session") var ErrBrokenSession = errors.New("broken session")
var errRepeatSessionClosing = errors.New("trying to close a closed session") var errRepeatSessionClosing = errors.New("trying to close a closed session")
type Obfuscator struct {
// Used in Stream.Write. Add multiplexing headers, encrypt and add TLS header
Obfs Obfser
// Remove TLS header, decrypt and unmarshall frames
Deobfs Deobfser
SessionKey []byte
}
type Session struct { type Session struct {
id uint32 id uint32
// Used in Stream.Write. Add multiplexing headers, encrypt and add TLS header *Obfuscator
obfs Obfser
// Remove TLS header, decrypt and unmarshall multiplexing headers
deobfs Deobfser
// This is supposed to read one TLS message, the same as GoQuiet's ReadTillDrain
obfsedRead func(net.Conn, []byte) (int, error)
SessionKey []byte // This is supposed to read one TLS message, the same as GoQuiet's ReadTillDrain
unitRead func(net.Conn, []byte) (int, error)
// atomic // atomic
nextStreamID uint32 nextStreamID uint32
@ -46,14 +50,12 @@ type Session struct {
terminalMsg atomic.Value terminalMsg atomic.Value
} }
func MakeSession(id uint32, valve *Valve, obfs Obfser, deobfs Deobfser, sessionKey []byte, obfsedRead func(net.Conn, []byte) (int, error)) *Session { func MakeSession(id uint32, valve *Valve, obfuscator *Obfuscator, unitReader func(net.Conn, []byte) (int, error)) *Session {
sesh := &Session{ sesh := &Session{
id: id, id: id,
obfsedRead: obfsedRead, unitRead: unitReader,
nextStreamID: 1, nextStreamID: 1,
obfs: obfs, Obfuscator: obfuscator,
deobfs: deobfs,
SessionKey: sessionKey,
streams: make(map[uint32]*Stream), streams: make(map[uint32]*Stream),
acceptCh: make(chan *Stream, acceptBacklog), acceptCh: make(chan *Stream, acceptBacklog),
} }

View File

@ -93,7 +93,7 @@ func (s *Stream) Write(in []byte) (n int, err error) {
Payload: in, Payload: in,
} }
tlsRecord, err := s.session.obfs(f) tlsRecord, err := s.session.Obfs(f)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -136,7 +136,7 @@ func (s *Stream) Close() error {
Closing: 1, Closing: 1,
Payload: pad, Payload: pad,
} }
tlsRecord, _ := s.session.obfs(f) tlsRecord, _ := s.session.Obfs(f)
s.session.sb.send(tlsRecord) s.session.sb.send(tlsRecord)
s._close() s._close()

View File

@ -11,14 +11,10 @@ import (
) )
func setupSesh() *Session { func setupSesh() *Session {
UID := make([]byte, 16) sessionKey := make([]byte, 32)
rand.Read(UID) rand.Read(sessionKey)
tthKey := make([]byte, 32) obfuscator, _ := util.GenerateObfs(0x00, sessionKey)
rand.Read(tthKey) return MakeSession(0, UNLIMITED_VALVE, obfuscator, util.ReadTLS)
crypto := &Plain{}
obfs := MakeObfs(tthKey, crypto)
deobfs := MakeDeobfs(tthKey, crypto)
return MakeSession(0, UNLIMITED_VALVE, obfs, deobfs, util.ReadTLS)
} }
type blackhole struct { type blackhole struct {
@ -66,38 +62,3 @@ func BenchmarkStream_Write(b *testing.B) {
b.SetBytes(PAYLOAD_LEN) b.SetBytes(PAYLOAD_LEN)
} }
} }
/*
func BenchmarkStream_Write(b *testing.B) {
mc := mock_conn.NewConn()
go func(){
w := bufio.NewWriter(ioutil.Discard)
for {
_, err := w.ReadFrom(mc.Server)
if err != nil {
log.Println(err)
return
}
}
}()
sesh := setupSesh()
sesh.AddConnection(mc.Client)
testData := make([]byte,PAYLOAD_LEN)
rand.Read(testData)
stream,_ := sesh.OpenStream()
b.ResetTimer()
for i:=0;i<b.N;i++{
_,err := stream.Write(testData)
if err != nil {
b.Error(
"For","stream write",
"got",err,
)
}
b.SetBytes(PAYLOAD_LEN)
}
}
*/

View File

@ -120,12 +120,12 @@ func (sb *switchboard) closeAll() {
sb.cesM.RUnlock() sb.cesM.RUnlock()
} }
// deplex function costantly reads from a TCP connection, call deobfs and distribute it // deplex function costantly reads from a TCP connection, call Deobfs and distribute it
// to the corresponding stream // to the corresponding stream
func (sb *switchboard) deplex(ce *connEnclave) { func (sb *switchboard) deplex(ce *connEnclave) {
buf := make([]byte, 20480) buf := make([]byte, 20480)
for { for {
n, err := sb.session.obfsedRead(ce.remoteConn, buf) n, err := sb.session.unitRead(ce.remoteConn, buf)
sb.rxWait(n) sb.rxWait(n)
sb.Valve.AddRx(int64(n)) sb.Valve.AddRx(int64(n))
if err != nil { if err != nil {
@ -135,7 +135,7 @@ func (sb *switchboard) deplex(ce *connEnclave) {
return return
} }
frame, err := sb.session.deobfs(buf[:n]) frame, err := sb.session.Deobfs(buf[:n])
if err != nil { if err != nil {
log.Debugf("Failed to decrypt a frame for session %v: %v", sb.session.id, err) log.Debugf("Failed to decrypt a frame for session %v: %v", sb.session.id, err)
continue continue

View File

@ -30,7 +30,7 @@ func (u *ActiveUser) DelSession(sessionID uint32) {
u.sessionsM.Unlock() u.sessionsM.Unlock()
} }
func (u *ActiveUser) GetSession(sessionID uint32, obfs mux.Obfser, deobfs mux.Deobfser, sessionKey []byte, obfsedRead func(net.Conn, []byte) (int, error)) (sesh *mux.Session, existing bool, err error) { func (u *ActiveUser) GetSession(sessionID uint32, obfuscator *mux.Obfuscator, unitReader func(net.Conn, []byte) (int, error)) (sesh *mux.Session, existing bool, err error) {
u.sessionsM.Lock() u.sessionsM.Lock()
defer u.sessionsM.Unlock() defer u.sessionsM.Unlock()
if sesh = u.sessions[sessionID]; sesh != nil { if sesh = u.sessions[sessionID]; sesh != nil {
@ -40,7 +40,7 @@ func (u *ActiveUser) GetSession(sessionID uint32, obfs mux.Obfser, deobfs mux.De
if err != nil { if err != nil {
return nil, false, err return nil, false, err
} }
sesh = mux.MakeSession(sessionID, u.valve, obfs, deobfs, sessionKey, obfsedRead) sesh = mux.MakeSession(sessionID, u.valve, obfuscator, unitReader)
u.sessions[sessionID] = sesh u.sessions[sessionID] = sesh
return sesh, false, nil return sesh, false, nil
} }

View File

@ -76,7 +76,7 @@ func ReadTLS(conn net.Conn, buffer []byte) (n int, err error) {
return return
} }
func GenerateObfs(encryptionMethod byte, sessionKey []byte) (obfs mux.Obfser, deobfs mux.Deobfser, err error) { func GenerateObfs(encryptionMethod byte, sessionKey []byte) (obfuscator *mux.Obfuscator, err error) {
var payloadCipher cipher.AEAD var payloadCipher cipher.AEAD
switch encryptionMethod { switch encryptionMethod {
case 0x00: case 0x00:
@ -97,7 +97,7 @@ func GenerateObfs(encryptionMethod byte, sessionKey []byte) (obfs mux.Obfser, de
return return
} }
default: default:
return nil, nil, errors.New("Unknown encryption method") return nil, errors.New("Unknown encryption method")
} }
headerCipher, err := aes.NewCipher(sessionKey) headerCipher, err := aes.NewCipher(sessionKey)
@ -105,8 +105,11 @@ func GenerateObfs(encryptionMethod byte, sessionKey []byte) (obfs mux.Obfser, de
return return
} }
obfs = mux.MakeObfs(headerCipher, payloadCipher) obfuscator = &mux.Obfuscator{
deobfs = mux.MakeDeobfs(headerCipher, payloadCipher) mux.MakeObfs(headerCipher, payloadCipher),
mux.MakeDeobfs(headerCipher, payloadCipher),
sessionKey,
}
return return
} }