diff --git a/go.mod b/go.mod index 114c723..fc22906 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,8 @@ go 1.24.1 require ( github.com/coder/websocket v1.8.12 + golang.org/x/crypto v0.37.0 golang.org/x/time v0.11.0 ) -require github.com/google/uuid v1.6.0 // indirect +require github.com/google/uuid v1.6.0 diff --git a/go.sum b/go.sum index 4d1c035..70da848 100644 --- a/go.sum +++ b/go.sum @@ -2,5 +2,7 @@ github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NA github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= +golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0= golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= diff --git a/pkg/interceptor/encrypt/encryption.go b/pkg/interceptor/encrypt/encryption.go index caf6a6b..df09e75 100644 --- a/pkg/interceptor/encrypt/encryption.go +++ b/pkg/interceptor/encrypt/encryption.go @@ -12,7 +12,7 @@ import ( ) type encryptor interface { - SetKey(key []byte) error + SetKey(encKey, decKey, sessionID []byte) error Encrypt(string, string, message.Message) (*Encrypted, error) Decrypt(*Encrypted) error Close() error @@ -20,27 +20,38 @@ type encryptor interface { type aes256 struct { encryptor cipher.AEAD + decryptor cipher.AEAD sessionID []byte mux sync.RWMutex } -func (a *aes256) SetKey(key []byte) error { - block, err := aes.NewCipher(key) - if err != nil { - return err - } +func (a *aes256) SetKey(encKey, decKey, sessionID []byte) error { + { + block, err := aes.NewCipher(encKey) + if err != nil { + return err + } - gcm, err := cipher.NewGCM(block) - if err != nil { - return err - } + gcm, err := cipher.NewGCM(block) + if err != nil { + return err + } - sessionID := make([]byte, 16) - if _, err := io.ReadFull(rand.Reader, sessionID); err != nil { - return err + a.encryptor = gcm } + { + block, err := aes.NewCipher(decKey) + if err != nil { + return err + } - a.encryptor = gcm + gcm, err := cipher.NewGCM(block) + if err != nil { + return err + } + + a.decryptor = gcm + } a.sessionID = sessionID return nil @@ -79,7 +90,7 @@ func (a *aes256) Decrypt(m *Encrypted) error { a.mux.Lock() defer a.mux.Unlock() - data, err := a.encryptor.Open(nil, m.Nonce, m.Payload, a.sessionID) + data, err := a.decryptor.Open(nil, m.Nonce, m.Payload, a.sessionID) if err != nil { return err } diff --git a/pkg/interceptor/encrypt/interceptor.go b/pkg/interceptor/encrypt/interceptor.go index 8444510..88d91e2 100644 --- a/pkg/interceptor/encrypt/interceptor.go +++ b/pkg/interceptor/encrypt/interceptor.go @@ -2,23 +2,31 @@ package encrypt import ( "context" + "crypto/rand" + "crypto/sha256" "errors" "fmt" + "io" + "os" "sync" "github.com/coder/websocket" + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/ed25519" + "golang.org/x/crypto/hkdf" "github.com/harshabose/skyline_sonata/serve/pkg/interceptor" "github.com/harshabose/skyline_sonata/serve/pkg/message" ) -var ServerPubKey []byte = nil +var ServerPubKey = []byte(os.Getenv("SERVER_ENCRYPT_PUB_KEY")) type Interceptor struct { interceptor.NoOpInterceptor - states map[interceptor.Connection]*state - mux sync.Mutex - ctx context.Context + states map[interceptor.Connection]*state + signKey []byte + mux sync.Mutex + ctx context.Context } func (i *Interceptor) BindSocketConnection(connection interceptor.Connection, writer interceptor.Writer, reader interceptor.Reader) error { @@ -129,6 +137,49 @@ func (i *Interceptor) Close() error { return nil } -func (i *Interceptor) exchangeKeys() { +func (i *Interceptor) exchangeKeys(connection interceptor.Connection) error { + var privKey [32]byte + var pubKey [32]byte + if _, err := io.ReadFull(rand.Reader, privKey[:]); err != nil { + return err + } + + curve25519.ScalarBaseMult(&pubKey, &privKey) + + salt := make([]byte, 16) + + if _, err := io.ReadFull(rand.Reader, salt[:]); err != nil { + return err + } + + signature := append(pubKey[:], salt...) + sign := ed25519.Sign(i.signKey, signature) + + state, exists := i.states[connection] + if !exists { + return errors.New("connection not registered") + } + + state.pubKey = pubKey[:] + state.privKey = privKey[:] + state.salt = salt + + return state.writer.Write(connection, websocket.MessageText, CreateEncryptionInit(i.ID, state.peerID, pubKey[:], sign, salt)) +} + +func derive(shared, salt []byte, info string) (encKey, decKey []byte, err error) { + hkdfReader := hkdf.New(sha256.New, shared, salt, []byte(info)) + + encKey = make([]byte, 32) + if _, err := io.ReadFull(hkdfReader, encKey); err != nil { + return nil, nil, err + } + + decKey = make([]byte, 32) + if _, err := io.ReadFull(hkdfReader, decKey); err != nil { + return nil, nil, err + } + + return encKey, decKey, nil } diff --git a/pkg/interceptor/encrypt/messages.go b/pkg/interceptor/encrypt/messages.go index 375346f..04431bb 100644 --- a/pkg/interceptor/encrypt/messages.go +++ b/pkg/interceptor/encrypt/messages.go @@ -1,9 +1,15 @@ package encrypt import ( + "crypto/rand" "errors" + "io" "time" + "github.com/coder/websocket" + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/ed25519" + "github.com/harshabose/skyline_sonata/serve/pkg/interceptor" "github.com/harshabose/skyline_sonata/serve/pkg/message" ) @@ -24,7 +30,7 @@ var ProtocolEncrypt message.Protocol = "encrypt" func (payload *Encrypted) Validate() error { if payload.Nonce == nil || len(payload.Nonce) <= 0 { - return errors.New("not valid") + return message.ErrorNotValid } return payload.BaseMessage.Validate() @@ -51,5 +57,125 @@ func (payload *Encrypted) Protocol() message.Protocol { return ProtocolEncrypt } -type KeyInit struct { +type EncryptionInit struct { + message.BaseMessage + PublicKey []byte `json:"public_key"` + Signature []byte `json:"signature"` + Salt []byte `json:"salt"` +} + +func CreateEncryptionInit(senderID, receiverID string, pubKey, Sig, Salt []byte) *EncryptionInit { + return &EncryptionInit{ + BaseMessage: message.BaseMessage{ + Header: message.Header{ + SenderID: senderID, + ReceiverID: receiverID, + Protocol: message.NoneProtocol, + }, + Payload: nil, + }, + PublicKey: pubKey, + Signature: Sig, + Salt: Salt, + } +} + +func (payload *EncryptionInit) Validate() error { + if len(payload.PublicKey) == 0 && len(payload.Signature) == 0 && len(payload.Salt) == 0 { + return message.ErrorNotValid + } + return payload.BaseMessage.Validate() +} + +func (payload *EncryptionInit) Process(_interceptor interceptor.Interceptor, connection interceptor.Connection) error { + i, ok := _interceptor.(*Interceptor) + if !ok { + return errors.New("invalid interceptor") + } + + signature := append(payload.PublicKey, payload.Salt...) + if ok := ed25519.Verify(ServerPubKey, signature, payload.Signature); !ok { + return errors.New("signature did not match") + } + + var privKey [32]byte + var pubKey [32]byte + + if _, err := io.ReadFull(rand.Reader, privKey[:]); err != nil { + return err + } + + curve25519.ScalarBaseMult(&pubKey, &privKey) + + state, exists := i.states[connection] + if !exists { + return errors.New("connection not registered") + } + state.peerID = payload.SenderID + state.pubKey = pubKey[:] + state.privKey = privKey[:] + state.sessionID = nil // TODO + state.salt = payload.Salt + + encKey, decKey, err := derive(nil, state.salt, i.ID) + if err != nil { + return err + } + + if err := state.encryptor.SetKey(encKey, decKey, state.sessionID); err != nil { + return err + } + + return state.writer.Write(connection, websocket.MessageText, CreateEncryptionResponse(i.ID, state.peerID, pubKey[:])) +} + +type EncryptionResponse struct { + message.BaseMessage + PublicKey []byte `json:"public_key"` +} + +func CreateEncryptionResponse(senderID, receiverID string, pub []byte) *EncryptionResponse { + return &EncryptionResponse{ + BaseMessage: message.BaseMessage{ + Header: message.Header{ + SenderID: senderID, + ReceiverID: receiverID, + Protocol: message.NoneProtocol, + }, + Payload: nil, + }, + PublicKey: pub, + } +} + +func (payload *EncryptionResponse) Validate() error { + if payload.PublicKey == nil || len(payload.PublicKey) == 0 { + return message.ErrorNotValid + } + + return payload.BaseMessage.Validate() +} + +func (payload *EncryptionResponse) Process(_interceptor interceptor.Interceptor, connection interceptor.Connection) error { + i, ok := _interceptor.(*Interceptor) + if !ok { + return errors.New("invalid interceptor") + } + + state, exists := i.states[connection] + if !exists { + return errors.New("connection not registered") + } + + shared, err := curve25519.X25519(state.privKey, payload.PublicKey) + if err != nil { + return err + } + + encKey, decKey, err := derive(shared, state.salt, i.ID) + if err != nil { + return err + } + + return state.encryptor.SetKey(encKey, decKey, state.sessionID) } diff --git a/pkg/interceptor/encrypt/state.go b/pkg/interceptor/encrypt/state.go index 4973b57..36852aa 100644 --- a/pkg/interceptor/encrypt/state.go +++ b/pkg/interceptor/encrypt/state.go @@ -11,7 +11,11 @@ type stats struct { type state struct { stats - id string + privKey []byte + pubKey []byte + salt []byte + sessionID []byte + peerID string encryptor encryptor writer interceptor.Writer reader interceptor.Reader