Refactor server transport

This commit is contained in:
Andy Wang 2020-04-08 21:37:21 +01:00
parent 7bfae8accd
commit 693544659f
5 changed files with 6 additions and 20 deletions

View File

@ -73,7 +73,7 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
var sessionKey [32]byte var sessionKey [32]byte
util.CryptoRandRead(sessionKey[:]) util.CryptoRandRead(sessionKey[:])
obfuscator, err := mux.MakeObfuscator(ci.EncryptionMethod, sessionKey, ci.Transport.HasRecordLayer()) obfuscator, err := mux.MakeObfuscator(ci.EncryptionMethod, sessionKey)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
goWeb() goWeb()
@ -93,7 +93,6 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
seshConfig := mux.SessionConfig{ seshConfig := mux.SessionConfig{
Obfuscator: obfuscator, Obfuscator: obfuscator,
Valve: nil, Valve: nil,
UnitRead: ci.Transport.UnitReadFunc(),
} }
sesh := mux.MakeSession(0, seshConfig) sesh := mux.MakeSession(0, seshConfig)
sesh.AddConnection(preparedConn) sesh.AddConnection(preparedConn)
@ -125,7 +124,6 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
seshConfig := mux.SessionConfig{ seshConfig := mux.SessionConfig{
Obfuscator: obfuscator, Obfuscator: obfuscator,
Valve: nil, Valve: nil,
UnitRead: ci.Transport.UnitReadFunc(),
Unordered: ci.Unordered, Unordered: ci.Unordered,
} }
sesh, existing, err := user.GetSession(ci.SessionId, seshConfig) sesh, existing, err := user.GetSession(ci.SessionId, seshConfig)

View File

@ -15,9 +15,7 @@ type TLS struct{}
var ErrBadClientHello = errors.New("non (or malformed) ClientHello") var ErrBadClientHello = errors.New("non (or malformed) ClientHello")
func (TLS) String() string { return "TLS" } func (TLS) String() string { return "TLS" }
func (TLS) HasRecordLayer() bool { return true }
func (TLS) UnitReadFunc() func(net.Conn, []byte) (int, error) { return util.ReadTLS }
func (TLS) processFirstPacket(clientHello []byte, privateKey crypto.PrivateKey) (fragments authFragments, respond Responder, err error) { func (TLS) processFirstPacket(clientHello []byte, privateKey crypto.PrivateKey) (fragments authFragments, respond Responder, err error) {
ch, err := parseClientHello(clientHello) ch, err := parseClientHello(clientHello)
@ -40,18 +38,18 @@ func (TLS) processFirstPacket(clientHello []byte, privateKey crypto.PrivateKey)
func (TLS) makeResponder(clientHelloSessionId []byte, sharedSecret [32]byte) Responder { func (TLS) makeResponder(clientHelloSessionId []byte, sharedSecret [32]byte) Responder {
respond := func(originalConn net.Conn, sessionKey [32]byte) (preparedConn net.Conn, err error) { respond := func(originalConn net.Conn, sessionKey [32]byte) (preparedConn net.Conn, err error) {
preparedConn = originalConn
reply, err := composeReply(clientHelloSessionId, sharedSecret, sessionKey) reply, err := composeReply(clientHelloSessionId, sharedSecret, sessionKey)
if err != nil { if err != nil {
err = fmt.Errorf("failed to compose TLS reply: %v", err) err = fmt.Errorf("failed to compose TLS reply: %v", err)
return return
} }
_, err = preparedConn.Write(reply) _, err = originalConn.Write(reply)
if err != nil { if err != nil {
err = fmt.Errorf("failed to write TLS reply: %v", err) err = fmt.Errorf("failed to write TLS reply: %v", err)
go preparedConn.Close() go originalConn.Close()
return return
} }
preparedConn = &util.TLSConn{Conn: originalConn}
return return
} }
return respond return respond

View File

@ -8,8 +8,6 @@ import (
type Responder = func(originalConn net.Conn, sessionKey [32]byte) (preparedConn net.Conn, err error) type Responder = func(originalConn net.Conn, sessionKey [32]byte) (preparedConn net.Conn, err error)
type Transport interface { type Transport interface {
HasRecordLayer() bool
UnitReadFunc() func(net.Conn, []byte) (int, error)
processFirstPacket(reqPacket []byte, privateKey crypto.PrivateKey) (authFragments, Responder, error) processFirstPacket(reqPacket []byte, privateKey crypto.PrivateKey) (authFragments, Responder, error)
} }

View File

@ -15,9 +15,7 @@ import (
type WebSocket struct{} type WebSocket struct{}
func (WebSocket) String() string { return "WebSocket" } func (WebSocket) String() string { return "WebSocket" }
func (WebSocket) HasRecordLayer() bool { return false }
func (WebSocket) UnitReadFunc() func(net.Conn, []byte) (int, error) { return util.ReadWebSocket }
func (WebSocket) processFirstPacket(reqPacket []byte, privateKey crypto.PrivateKey) (fragments authFragments, respond Responder, err error) { func (WebSocket) processFirstPacket(reqPacket []byte, privateKey crypto.PrivateKey) (fragments authFragments, respond Responder, err error) {
var req *http.Request var req *http.Request

View File

@ -4,7 +4,6 @@ import (
"errors" "errors"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"io" "io"
"net"
"sync" "sync"
"time" "time"
) )
@ -75,8 +74,3 @@ func (ws *WebSocketConn) SetDeadline(t time.Time) error {
} }
return nil return nil
} }
// ws unit reader
func ReadWebSocket(conn net.Conn, buffer []byte) (n int, err error) {
return conn.Read(buffer)
}