diff --git a/cmd/keygen/keygen.go b/cmd/keygen/keygen.go index 4feb8e8..b9d5b37 100644 --- a/cmd/keygen/keygen.go +++ b/cmd/keygen/keygen.go @@ -10,21 +10,38 @@ import ( var b64 = base64.StdEncoding.EncodeToString func main() { + for { + fmt.Println("1 to generate UID, 2 to generate a key pair") - UID := make([]byte, 32) - rand.Read(UID) + var sel int + _, err := fmt.Scanln(&sel) + if err != nil { + fmt.Println("Please enter a number") + continue + } + if sel != 1 && sel != 2 { + fmt.Println("Please enter 1 or 2") + continue + } - ec := ecdh.NewCurve25519ECDH() - staticPv, staticPub, _ := ec.GenerateKey(rand.Reader) - marshPub := ec.Marshal(staticPub) - marshPv := staticPv.(*[32]byte)[:] + if sel == 1 { + UID := make([]byte, 32) + rand.Read(UID) + fmt.Printf("\"UID\":\"%v\"\n", b64(UID)) + } else if sel == 2 { - fmt.Printf("USER: \n") - fmt.Printf("\"UID\":\"%v\",\n", b64(UID)) - fmt.Printf("\"PublicKey\":\"%v\"\n", b64(marshPub)) + ec := ecdh.NewCurve25519ECDH() + staticPv, staticPub, _ := ec.GenerateKey(rand.Reader) + marshPub := ec.Marshal(staticPub) + marshPv := staticPv.(*[32]byte)[:] - fmt.Println("=========================================") + fmt.Printf("USER: \n") + fmt.Printf("\"PublicKey\":\"%v\"\n", b64(marshPub)) - fmt.Printf("SERVER: \n") - fmt.Printf("\"PrivateKey\":\"%v\"\n", b64(marshPv)) + fmt.Println("=========================================") + + fmt.Printf("SERVER: \n") + fmt.Printf("\"PrivateKey\":\"%v\"\n", b64(marshPv)) + } + } } diff --git a/internal/server/usermanager/controller.go b/internal/server/usermanager/controller.go index d0f5f7a..d4854b4 100644 --- a/internal/server/usermanager/controller.go +++ b/internal/server/usermanager/controller.go @@ -47,6 +47,8 @@ func (c *controller) HandleRequest(req []byte) ([]byte, error) { if err == ErrInvalidMac { log.Printf("!!!CONTROL MESSAGE AND HMAC MISMATCH!!!\n raw request:\n%x\ndecrypted msg:\n%x", req, plain) return nil, err + } else { + return c.respond([]byte(err.Error())), nil } switch plain[0] { @@ -102,6 +104,7 @@ func (c *controller) HandleRequest(req []byte) ([]byte, error) { } var ErrInvalidMac = errors.New("Mac mismatch") +var errMsgTooShort = errors.New("Message length is less than 54") // protocol: [TLS record layer 5 bytes][IV 16 bytes][data][hmac 32 bytes] func (c *controller) respond(resp []byte) []byte { @@ -127,6 +130,9 @@ func (c *controller) respond(resp []byte) []byte { } func (c *controller) checkAndDecrypt(data []byte) ([]byte, error) { + if len(data) < 54 { + return nil, errMsgTooShort + } macIndex := len(data) - 32 mac := hmac.New(sha256.New, c.adminUID[16:32]) mac.Write(data[5:macIndex])