diff --git a/pkg/interceptor/encrypt/_interceptor.go b/pkg/interceptor/encrypt/_interceptor.go deleted file mode 100644 index 9948171..0000000 --- a/pkg/interceptor/encrypt/_interceptor.go +++ /dev/null @@ -1,216 +0,0 @@ -package encrypt - -import ( - "context" - "crypto/aes" - "crypto/cipher" - "crypto/rand" - "encoding/binary" - "fmt" - "io" - "sync" - "time" - - "github.com/coder/websocket" - - "github.com/harshabose/skyline_sonata/serve/pkg/interceptor" -) - -type collection struct { - encryptor cipher.AEAD // For encrypting outgoing messages - sessionID []byte - ctx context.Context - cancel context.CancelFunc -} - -type Interceptor struct { - interceptor.NoOpInterceptor - collection map[*websocket.Conn]*collection - mux sync.Mutex - ctx context.Context -} - -func (encrypt *Interceptor) BindSocketConnection(connection *websocket.Conn) error { - encrypt.mux.Lock() - defer encrypt.mux.Unlock() - - // Generate a key for this connection - key := make([]byte, 32) // AES-256 - if _, err := io.ReadFull(rand.Reader, key); err != nil { - return fmt.Errorf("failed to generate encryption key: %w", err) - } - - // Create AES cipher - block, err := aes.NewCipher(key) - if err != nil { - return fmt.Errorf("failed to create AES cipher: %w", err) - } - - // Create GCM mode encryptor - gcm, err := cipher.NewGCM(block) - if err != nil { - return fmt.Errorf("failed to create GCM mode: %w", err) - } - - sessionID := make([]byte, 16) - if _, err := io.ReadFull(rand.Reader, sessionID); err != nil { - fmt.Println("Failed to generate new session messageID:", err) - } - - ctx, cancel := context.WithCancel(encrypt.ctx) - encrypt.collection[connection] = &collection{ - encryptor: gcm, - sessionID: nil, - ctx: ctx, - cancel: cancel, - } - - // TODO: Exchange keys with the peer using a key exchange protocol like Diffie-Hellman - // TODO: Store different keys for encryption and decryption - - go encrypt.loop(connection) - - return nil -} - -func (encrypt *Interceptor) BindSocketWriter(writer interceptor.Writer) interceptor.Writer { - return interceptor.WriterFunc(func(conn *websocket.Conn, messageType websocket.MessageType, data []byte) error { - encrypt.mux.Lock() - collection, exists := encrypt.collection[conn] - encrypt.mux.Unlock() - - if !exists || collection.encryptor == nil { - // No encryption configured for this connection yet - // Pass through unencrypted - return writer.Write(conn, messageType, data) - } - - // Generate a nonce for this message - nonce := make([]byte, 12) // GCM typically uses a 12-byte nonce - if _, err := io.ReadFull(rand.Reader, nonce); err != nil { - return fmt.Errorf("failed to generate nonce: %w", err) - } - - // Encrypt the data - encryptor := collection.encryptor - sessionID := collection.sessionID - encryptedData := encryptor.Seal(nil, nonce, data, sessionID) - - // Format the encrypted message: - // [2-byte nonce length][nonce][encrypted data] - finalData := make([]byte, 2+len(nonce)+len(encryptedData)) - binary.BigEndian.PutUint16(finalData[:2], uint16(len(nonce))) - copy(finalData[2:], nonce) - copy(finalData[2+len(nonce):], encryptedData) - - // Send the encrypted message - return writer.Write(conn, messageType, finalData) - }) -} - -func (encrypt *Interceptor) BindSocketReader(reader interceptor.Reader) interceptor.Reader { - return interceptor.ReaderFunc(func(conn *websocket.Conn) (websocket.MessageType, []byte, error) { - // Read the encrypted message - messageType, encryptedData, err := reader.Read(conn) - if err != nil { - return messageType, encryptedData, err - } - - encrypt.mux.Lock() - collection, exists := encrypt.collection[conn] - encrypt.mux.Unlock() - - if !exists || collection.encryptor == nil || len(encryptedData) < 2 { - // No decryption configured or data too short to be encrypted - // Pass through as-is - return messageType, encryptedData, nil - } - - // Extract nonce length - nonceLen := binary.BigEndian.Uint16(encryptedData[:2]) - - // Ensure we have enough data for the nonce and at least some ciphertext - if len(encryptedData) < int(2+nonceLen) { - return messageType, encryptedData, fmt.Errorf("encrypted data too short") - } - - // Extract nonce and ciphertext - nonce := encryptedData[2 : 2+nonceLen] - ciphertext := encryptedData[2+nonceLen:] - - // Decrypt the data - decryptor := collection.encryptor - sessionID := collection.sessionID - plaintext, err := decryptor.Open(nil, nonce, ciphertext, sessionID) - if err != nil { - return messageType, encryptedData, fmt.Errorf("decryption failed: %w", err) - } - - return messageType, plaintext, nil - }) -} - -func (encrypt *Interceptor) UnBindSocketConnection(connection *websocket.Conn) { - encrypt.mux.Lock() - defer encrypt.mux.Unlock() - - collection, exists := encrypt.collection[connection] - if !exists { - fmt.Println("connection does not exists") - return - } - - collection.cancel() - delete(encrypt.collection, connection) -} - -func (encrypt *Interceptor) UnBindSocketWriter(_ interceptor.Writer) {} - -func (encrypt *Interceptor) UnBindSocketReader(_ interceptor.Reader) {} - -func (encrypt *Interceptor) Close() error { - encrypt.mux.Lock() - defer encrypt.mux.Unlock() - - encrypt.collection = make(map[*websocket.Conn]*collection) - - return nil -} - -func (encrypt *Interceptor) loop(connection *websocket.Conn) { - - encrypt.mux.Lock() - collection, exists := encrypt.collection[connection] - if !exists { - fmt.Println("connection does not exists") - return - } - ctx := collection.ctx - encrypt.mux.Unlock() - - timer := time.NewTicker(5 * time.Minute) - defer timer.Stop() - - for { - select { - case <-timer.C: - encrypt.mux.Lock() - - newSessionID := make([]byte, 16) - if _, err := io.ReadFull(rand.Reader, newSessionID); err != nil { - fmt.Println("Failed to generate new session messageID:", err) - continue - } - - if collection, exists := encrypt.collection[connection]; exists { - collection.sessionID = nil // Keep nil until sending to peer mechanism is set - } - - // send the update sessionID to peer - - encrypt.mux.Unlock() - case <-ctx.Done(): - return - } - } -} diff --git a/pkg/interceptor/encrypt/encryption.go b/pkg/interceptor/encrypt/encryption.go index 6c0ce82..c5c1ff5 100644 --- a/pkg/interceptor/encrypt/encryption.go +++ b/pkg/interceptor/encrypt/encryption.go @@ -4,40 +4,38 @@ import ( "crypto/aes" "crypto/cipher" "crypto/rand" - "encoding/binary" - "errors" "io" "sync" + "time" + "github.com/harshabose/skyline_sonata/serve/pkg/interceptor" "github.com/harshabose/skyline_sonata/serve/pkg/message" ) type encryptor interface { - Encrypt(message.Message) (*Message, error) - Decrypt(*Message) (message.Message, error) + Encrypt(message.Message) (*interceptor.BaseMessage, error) + Decrypt(*interceptor.BaseMessage) (message.Message, error) Close() error } type aes256 struct { + key []byte encryptor cipher.AEAD sessionID []byte mux sync.RWMutex } func createAES256() (*aes256, error) { - // generate key key := make([]byte, 32) if _, err := io.ReadFull(rand.Reader, key); err != nil { return nil, err } - // create cipher block block, err := aes.NewCipher(key) if err != nil { return nil, err } - // create GCM mode encryptor gcm, err := cipher.NewGCM(block) if err != nil { return nil, err @@ -49,12 +47,13 @@ func createAES256() (*aes256, error) { } return &aes256{ + key: key, encryptor: gcm, sessionID: sessionID, }, nil } -func (a *aes256) Encrypt(message message.Message) (*Message, error) { +func (a *aes256) Encrypt(senderID, receiverID string, message message.Message) (*interceptor.BaseMessage, error) { nonce := make([]byte, 12) if _, err := io.ReadFull(rand.Reader, nonce); err != nil { return nil, err @@ -69,19 +68,28 @@ func (a *aes256) Encrypt(message message.Message) (*Message, error) { defer a.mux.Unlock() encryptedData := a.encryptor.Seal(nil, nonce, data, a.sessionID) - finalData := make([]byte, 2+len(nonce)+len(encryptedData)) - binary.BigEndian.PutUint16(finalData[:2], uint16(len(nonce))) + payload := &Encrypted{Data: encryptedData, Nonce: nonce, Timestamp: time.Now()} - copy(finalData[2:], nonce) - copy(finalData[2+len(nonce):], encryptedData) - - return CreateMessage(PayloadEncryptedType, finalData), nil + return CreateMessage(senderID, receiverID, payload) } -func (a *aes256) Decrypt(message *Message) (message.Message, error) { +func (a *aes256) Decrypt(message *Encrypted) (message.Message, error) { + _, err := a.encryptor.Open(nil, message.Nonce, message.Data, a.sessionID) + if err != nil { + return nil, err + } + // TODO: figure out how to create a message here + + return nil, nil } func (a *aes256) Close() error { + a.mux.Lock() + defer a.mux.Unlock() + a.sessionID = nil + a.encryptor = nil + + return nil } diff --git a/pkg/interceptor/encrypt/interceptor.go b/pkg/interceptor/encrypt/interceptor.go index 201d65b..4dcc514 100644 --- a/pkg/interceptor/encrypt/interceptor.go +++ b/pkg/interceptor/encrypt/interceptor.go @@ -1,7 +1,185 @@ package encrypt -import "github.com/harshabose/skyline_sonata/serve/pkg/interceptor" +import ( + "context" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/binary" + "fmt" + "io" + "sync" + "time" + + "github.com/coder/websocket" + + "github.com/harshabose/skyline_sonata/serve/pkg/interceptor" + "github.com/harshabose/skyline_sonata/serve/pkg/message" +) type Interceptor struct { interceptor.NoOpInterceptor + states map[interceptor.Connection]*state + mux sync.Mutex + ctx context.Context +} + +func (encrypt *Interceptor) BindSocketConnection(connection *websocket.Conn) error { + encrypt.mux.Lock() + defer encrypt.mux.Unlock() + + // TODO: Exchange keys with the peer using a key exchange protocol like Diffie-Hellman + // TODO: Store different keys for encryption and decryption + + go encrypt.loop(connection) + + return nil +} + +func (encrypt *Interceptor) BindSocketWriter(writer interceptor.Writer) interceptor.Writer { + return interceptor.WriterFunc(func(connection interceptor.Connection, messageType websocket.MessageType, message message.Message) error { + state, exists := encrypt.states[connection] + if !exists { + return + } + }) + return interceptor.WriterFunc(func(conn *websocket.Conn, messageType websocket.MessageType, data []byte) error { + encrypt.mux.Lock() + collection, exists := encrypt.states[conn] + encrypt.mux.Unlock() + + if !exists || collection.encryptor == nil { + // No encryption configured for this connection yet + // Pass through unencrypted + return writer.Write(conn, messageType, data) + } + + // Generate a nonce for this message + nonce := make([]byte, 12) // GCM typically uses a 12-byte nonce + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return fmt.Errorf("failed to generate nonce: %w", err) + } + + // Encrypt the data + encryptor := collection.encryptor + sessionID := collection.sessionID + encryptedData := encryptor.Seal(nil, nonce, data, sessionID) + + // Format the encrypted message: + // [2-byte nonce length][nonce][encrypted data] + finalData := make([]byte, 2+len(nonce)+len(encryptedData)) + binary.BigEndian.PutUint16(finalData[:2], uint16(len(nonce))) + copy(finalData[2:], nonce) + copy(finalData[2+len(nonce):], encryptedData) + + // Send the encrypted message + return writer.Write(conn, messageType, finalData) + }) +} + +func (encrypt *Interceptor) BindSocketReader(reader interceptor.Reader) interceptor.Reader { + return interceptor.ReaderFunc(func(conn *websocket.Conn) (websocket.MessageType, []byte, error) { + // Read the encrypted message + messageType, encryptedData, err := reader.Read(conn) + if err != nil { + return messageType, encryptedData, err + } + + encrypt.mux.Lock() + collection, exists := encrypt.states[conn] + encrypt.mux.Unlock() + + if !exists || collection.encryptor == nil || len(encryptedData) < 2 { + // No decryption configured or data too short to be encrypted + // Pass through as-is + return messageType, encryptedData, nil + } + + // Extract nonce length + nonceLen := binary.BigEndian.Uint16(encryptedData[:2]) + + // Ensure we have enough data for the nonce and at least some ciphertext + if len(encryptedData) < int(2+nonceLen) { + return messageType, encryptedData, fmt.Errorf("encrypted data too short") + } + + // Extract nonce and ciphertext + nonce := encryptedData[2 : 2+nonceLen] + ciphertext := encryptedData[2+nonceLen:] + + // Decrypt the data + decryptor := collection.encryptor + sessionID := collection.sessionID + plaintext, err := decryptor.Open(nil, nonce, ciphertext, sessionID) + if err != nil { + return messageType, encryptedData, fmt.Errorf("decryption failed: %w", err) + } + + return messageType, plaintext, nil + }) +} + +func (encrypt *Interceptor) UnBindSocketConnection(connection *websocket.Conn) { + encrypt.mux.Lock() + defer encrypt.mux.Unlock() + + collection, exists := encrypt.states[connection] + if !exists { + fmt.Println("connection does not exists") + return + } + + collection.cancel() + delete(encrypt.states, connection) +} + +func (encrypt *Interceptor) UnBindSocketWriter(_ interceptor.Writer) {} + +func (encrypt *Interceptor) UnBindSocketReader(_ interceptor.Reader) {} + +func (encrypt *Interceptor) Close() error { + encrypt.mux.Lock() + defer encrypt.mux.Unlock() + + encrypt.states = make(map[*websocket.Conn]*collection) + + return nil +} + +func (encrypt *Interceptor) loop(connection *websocket.Conn) { + + encrypt.mux.Lock() + collection, exists := encrypt.states[connection] + if !exists { + fmt.Println("connection does not exists") + return + } + ctx := collection.ctx + encrypt.mux.Unlock() + + timer := time.NewTicker(5 * time.Minute) + defer timer.Stop() + + for { + select { + case <-timer.C: + encrypt.mux.Lock() + + newSessionID := make([]byte, 16) + if _, err := io.ReadFull(rand.Reader, newSessionID); err != nil { + fmt.Println("Failed to generate new session messageID:", err) + continue + } + + if collection, exists := encrypt.states[connection]; exists { + collection.sessionID = nil // Keep nil until sending to peer mechanism is set + } + + // send the update sessionID to peer + + encrypt.mux.Unlock() + case <-ctx.Done(): + return + } + } } diff --git a/pkg/interceptor/encrypt/messages.go b/pkg/interceptor/encrypt/messages.go index 7246a9a..859ed19 100644 --- a/pkg/interceptor/encrypt/messages.go +++ b/pkg/interceptor/encrypt/messages.go @@ -2,22 +2,69 @@ package encrypt import ( "encoding/json" + "errors" + "time" + + "github.com/harshabose/skyline_sonata/serve/pkg/interceptor" + "github.com/harshabose/skyline_sonata/serve/pkg/message" ) -type PayloadType string +var ( + MainType interceptor.MainType = "encrypt" -const ( - PayloadEncryptedType PayloadType = "encrypt:encrypted" + EncryptedSubType interceptor.SubType = "encrypted" ) -type Message struct { - Type PayloadType `json:"type"` - Payload json.RawMessage `json:"payload"` -} - -func CreateMessage(_type PayloadType, payload json.RawMessage) *Message { - return &Message{ - Type: _type, - Payload: payload, +func CreateMessage(senderID, receiverID string, payload interceptor.Payload) (*interceptor.BaseMessage, error) { + data, err := payload.Marshal() + if err != nil { + return nil, err } + + return &interceptor.BaseMessage{ + BaseMessage: message.BaseMessage{ + Header: message.Header{ + SenderID: senderID, + ReceiverID: receiverID, + Protocol: interceptor.IProtocol, + }, + Payload: data, + }, + Header: interceptor.Header{ + + MainType: MainType, + SubType: payload.Type(), + }, + }, nil +} + +type Encrypted struct { + Data []byte `json:"data"` + Nonce []byte `json:"nonce"` + Timestamp time.Time `json:"timestamp"` +} + +func (payload *Encrypted) Marshal() ([]byte, error) { + return json.Marshal(payload) +} + +func (payload *Encrypted) Unmarshal(data []byte) error { + return json.Unmarshal(data, payload) +} + +func (payload *Encrypted) Validate() error { + if payload.Data == nil || payload.Nonce == nil || len(payload.Data) <= 0 || len(payload.Nonce) <= 0 { + return errors.New("not valid") + } + + return nil +} + +func (payload *Encrypted) Process(header interceptor.Header, i interceptor.Interceptor, connection interceptor.Connection) error { + // TODO implement me + panic("implement me") +} + +func (payload *Encrypted) Type() interceptor.SubType { + return EncryptedSubType } diff --git a/pkg/interceptor/encrypt/state.go b/pkg/interceptor/encrypt/state.go index 7e5445a..2c02b48 100644 --- a/pkg/interceptor/encrypt/state.go +++ b/pkg/interceptor/encrypt/state.go @@ -1 +1,10 @@ package encrypt + +import "context" + +type state struct { + id string + encryptor encryptor + cancel context.CancelFunc + ctx context.Context +} diff --git a/pkg/interceptor/interceptor.go b/pkg/interceptor/interceptor.go index 17e4e66..eb54591 100644 --- a/pkg/interceptor/interceptor.go +++ b/pkg/interceptor/interceptor.go @@ -6,6 +6,8 @@ import ( "sync" "github.com/coder/websocket" + + "github.com/harshabose/skyline_sonata/serve/pkg/message" ) // Registry maintains a collection of interceptor factories that can be used to @@ -110,29 +112,29 @@ type Writer interface { // Write sends a message to the connection // Takes the connection, message type, and message to write // Returns any error encountered during writing - Write(conn Connection, messageType websocket.MessageType, message Message) error + Write(conn Connection, messageType websocket.MessageType, message message.Message) error } // Reader is an interface for reading messages from a websocket connection type Reader interface { // Read reads a message from the connection // Returns the message type, message data, and any error - Read(conn Connection) (messageType websocket.MessageType, message Message, err error) + Read(conn Connection) (messageType websocket.MessageType, message message.Message, err error) } // ReaderFunc is a function type that implements the Reader interface -type ReaderFunc func(conn Connection) (messageType websocket.MessageType, message Message, err error) +type ReaderFunc func(conn Connection) (messageType websocket.MessageType, message message.Message, err error) // Read implements the Reader interface for ReaderFunc -func (f ReaderFunc) Read(conn Connection) (messageType websocket.MessageType, message Message, err error) { +func (f ReaderFunc) Read(conn Connection) (messageType websocket.MessageType, message message.Message, err error) { return f(conn) } // WriterFunc is a function type that implements the Writer interface -type WriterFunc func(conn Connection, messageType websocket.MessageType, message Message) error +type WriterFunc func(conn Connection, messageType websocket.MessageType, message message.Message) error // Write implements the Writer interface for WriterFunc -func (f WriterFunc) Write(conn Connection, messageType websocket.MessageType, message Message) error { +func (f WriterFunc) Write(conn Connection, messageType websocket.MessageType, message message.Message) error { return f(conn, messageType, message) } diff --git a/pkg/interceptor/message.go b/pkg/interceptor/message.go index 36caf20..c113695 100644 --- a/pkg/interceptor/message.go +++ b/pkg/interceptor/message.go @@ -1,37 +1,22 @@ package interceptor -import "encoding/json" - -type Message interface { - Marshal() ([]byte, error) - Unmarshal([]byte) error -} +import ( + "github.com/harshabose/skyline_sonata/serve/pkg/message" +) type ( - Protocol string SubType string MainType string ) type Header struct { - SenderID string `json:"source_id"` - ReceiverID string `json:"destination_id"` - Protocol Protocol `json:"protocol"` - MainType MainType `json:"main_type"` - SubType SubType `json:"sub_type"` + MainType MainType `json:"main_type"` + SubType SubType `json:"sub_type"` } -var IProtocol Protocol = "interceptor" +var IProtocol message.Protocol = "interceptor" -type BaseMessage struct { // This actually needs to be interceptor module and Message interface should be in its own module +type BaseMessage struct { Header - Payload json.RawMessage `json:"payload"` -} - -func (msg *BaseMessage) Marshal() ([]byte, error) { - return json.Marshal(msg) -} - -func (msg *BaseMessage) Unmarshal(data []byte) error { - return json.Unmarshal(data, msg) + message.BaseMessage } diff --git a/pkg/interceptor/ping/interceptor.go b/pkg/interceptor/ping/interceptor.go index dce8cf1..9446c8e 100644 --- a/pkg/interceptor/ping/interceptor.go +++ b/pkg/interceptor/ping/interceptor.go @@ -10,6 +10,7 @@ import ( "github.com/google/uuid" "github.com/harshabose/skyline_sonata/serve/pkg/interceptor" + "github.com/harshabose/skyline_sonata/serve/pkg/message" ) type Interceptor struct { @@ -47,17 +48,17 @@ func (i *Interceptor) BindSocketConnection(connection interceptor.Connection, wr } func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) interceptor.Writer { - return interceptor.WriterFunc(func(conn interceptor.Connection, messageType websocket.MessageType, message interceptor.Message) error { + return interceptor.WriterFunc(func(conn interceptor.Connection, messageType websocket.MessageType, message message.Message) error { i.Mutex.Lock() defer i.Mutex.Unlock() - msg, ok := message.(*Message) - if !ok || (msg.MainType != "ping" && msg.SubType != "ping") { + msg, ok := message.(*interceptor.BaseMessage) + if !ok || (msg.Protocol != interceptor.IProtocol && msg.MainType != MainType) { return writer.Write(conn, messageType, message) } - payload := &Ping{} - if err := payload.Unmarshal(msg.Payload); err != nil { + payload, err := PayloadUnmarshal(msg.SubType, msg.Payload) + if err != nil { return writer.Write(conn, messageType, message) } @@ -72,7 +73,7 @@ func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) intercept } func (i *Interceptor) InterceptSocketReader(reader interceptor.Reader) interceptor.Reader { - return interceptor.ReaderFunc(func(conn interceptor.Connection) (messageType websocket.MessageType, message interceptor.Message, err error) { + return interceptor.ReaderFunc(func(conn interceptor.Connection) (messageType websocket.MessageType, message message.Message, err error) { messageType, message, err = reader.Read(conn) if err != nil { return messageType, message, err @@ -81,14 +82,14 @@ func (i *Interceptor) InterceptSocketReader(reader interceptor.Reader) intercept i.Mutex.Lock() defer i.Mutex.Unlock() - msg, ok := message.(*Message) - if !ok { + msg, ok := message.(*interceptor.BaseMessage) + if !ok || (msg.Protocol != interceptor.IProtocol && msg.MainType != MainType) { return messageType, message, nil } - payload := &Pong{} - if err := payload.Unmarshal(msg.Payload); err != nil { - return messageType, message, nil + payload, err := PayloadUnmarshal(msg.SubType, msg.Payload) + if err != nil { + return messageType, message, err } if _, exists := i.states[conn]; exists { @@ -161,7 +162,7 @@ func (i *Interceptor) loop(ctx context.Context, interval time.Duration, connecti } } -func (payload *Ping) Process(header interceptor.Header, interceptor interceptor.Interceptor, connection interceptor.Connection) error { +func (payload *Ping) Process(_ interceptor.Header, interceptor interceptor.Interceptor, connection interceptor.Connection) error { if err := payload.Validate(); err != nil { return err } diff --git a/pkg/interceptor/ping/messages.go b/pkg/interceptor/ping/messages.go index 01b2f7b..b8627a2 100644 --- a/pkg/interceptor/ping/messages.go +++ b/pkg/interceptor/ping/messages.go @@ -2,29 +2,48 @@ package ping import ( "encoding/json" + "errors" "time" "github.com/harshabose/skyline_sonata/serve/pkg/interceptor" ) -type Message struct { - interceptor.BaseMessage +var ( + MainType interceptor.MainType = "ping-pong" + PingSubType interceptor.SubType = "ping" + PongSubType interceptor.SubType = "pong" + + subTypeMap = map[interceptor.SubType]interceptor.Payload{ + PingSubType: &Ping{}, + PongSubType: &Pong{}, + } +) + +func PayloadUnmarshal(sub interceptor.SubType, p json.RawMessage) (interceptor.Payload, error) { + if payload, exists := subTypeMap[sub]; exists { + if err := payload.Unmarshal(p); err != nil { + return nil, err + } + return payload, nil + } + + return nil, errors.New("processor does not exist for given type") } -func CreateMessage(senderID, receiverID string, payload interceptor.Payload) (*Message, error) { +func CreateMessage(senderID, receiverID string, payload interceptor.Payload) (*interceptor.BaseMessage, error) { data, err := payload.Marshal() if err != nil { return nil, err } - return &Message{ - BaseMessage: interceptor.BaseMessage{ - Header: interceptor.Header{ - SenderID: senderID, - ReceiverID: receiverID, - }, - Payload: data, + return &interceptor.BaseMessage{ + Header: interceptor.Header{ + SenderID: senderID, + ReceiverID: receiverID, + MainType: MainType, + SubType: payload.Type(), }, + Payload: data, }, nil } @@ -69,6 +88,10 @@ func (payload *Ping) Validate() error { return nil } +func (payload *Ping) Type() interceptor.SubType { + return PingSubType +} + // Pong represents a response to a ping message, confirming connection health. // It contains the original ping's message ID and timestamp, plus its own timestamp, // allowing the server to calculate the round-trip time. @@ -110,3 +133,7 @@ func (payload *Pong) Unmarshal(data []byte) error { func (payload *Pong) Validate() error { return nil } + +func (payload *Pong) Type() interceptor.SubType { + return PongSubType +} diff --git a/pkg/interceptor/pong/interceptor.go b/pkg/interceptor/pong/interceptor.go index e213631..8578832 100644 --- a/pkg/interceptor/pong/interceptor.go +++ b/pkg/interceptor/pong/interceptor.go @@ -9,6 +9,7 @@ import ( "github.com/coder/websocket" "github.com/harshabose/skyline_sonata/serve/pkg/interceptor" + "github.com/harshabose/skyline_sonata/serve/pkg/message" ) type Interceptor struct { @@ -43,17 +44,17 @@ func (i *Interceptor) BindSocketConnection(connection interceptor.Connection, wr } func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) interceptor.Writer { - return interceptor.WriterFunc(func(conn interceptor.Connection, messageType websocket.MessageType, message interceptor.Message) error { + return interceptor.WriterFunc(func(conn interceptor.Connection, messageType websocket.MessageType, message message.Message) error { i.Mutex.Lock() defer i.Mutex.Unlock() msg, ok := message.(*interceptor.BaseMessage) - if !ok || (msg.Header.Protocol != interceptor.IProtocol && msg.Header.MainType != MainType) { + if !ok || (msg.Protocol != interceptor.IProtocol && msg.Header.MainType != MainType) { return writer.Write(conn, messageType, message) } - payload := &Pong{} - if err := payload.Unmarshal(msg.Payload); err != nil { + payload, err := PayloadUnmarshal(msg.SubType, msg.Payload) + if err != nil { return writer.Write(conn, messageType, message) } @@ -68,7 +69,7 @@ func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) intercept } func (i *Interceptor) InterceptSocketReader(reader interceptor.Reader) interceptor.Reader { - return interceptor.ReaderFunc(func(conn interceptor.Connection) (messageType websocket.MessageType, message interceptor.Message, err error) { + return interceptor.ReaderFunc(func(conn interceptor.Connection) (messageType websocket.MessageType, message message.Message, err error) { messageType, message, err = reader.Read(conn) if err != nil { return messageType, message, err @@ -78,13 +79,13 @@ func (i *Interceptor) InterceptSocketReader(reader interceptor.Reader) intercept defer i.Mutex.Unlock() msg, ok := message.(*interceptor.BaseMessage) - if !ok || (msg.Header.Protocol != interceptor.IProtocol && msg.Header.MainType != MainType) { + if !ok || (msg.Protocol != interceptor.IProtocol && msg.Header.MainType != MainType) { return messageType, message, nil } - payload := &Ping{} - if err := payload.Unmarshal(msg.Payload); err != nil { - return messageType, message, nil + payload, err := PayloadUnmarshal(msg.SubType, msg.Payload) + if err != nil { + return messageType, message, err } if _, exists := i.states[conn]; exists { diff --git a/pkg/interceptor/pong/messages.go b/pkg/interceptor/pong/messages.go index a42ff3f..355f6e5 100644 --- a/pkg/interceptor/pong/messages.go +++ b/pkg/interceptor/pong/messages.go @@ -2,12 +2,34 @@ package pong import ( "encoding/json" + "errors" "time" "github.com/harshabose/skyline_sonata/serve/pkg/interceptor" ) -var MainType interceptor.MainType = "pong" +var ( + MainType interceptor.MainType = "ping-pong" + + PingSubType interceptor.SubType = "ping" + PongSubType interceptor.SubType = "pong" + + subTypeMap = map[interceptor.SubType]interceptor.Payload{ + PingSubType: &Ping{}, + PongSubType: &Pong{}, + } +) + +func PayloadUnmarshal(sub interceptor.SubType, p json.RawMessage) (interceptor.Payload, error) { + if payload, exists := subTypeMap[sub]; exists { + if err := payload.Unmarshal(p); err != nil { + return nil, err + } + return payload, nil + } + + return nil, errors.New("processor does not exist for given type") +} func CreateMessage(senderID, receiverID string, payload interceptor.Payload) (*interceptor.BaseMessage, error) { data, err := payload.Marshal() diff --git a/pkg/interceptor/room/interceptor.go b/pkg/interceptor/room/interceptor.go index e9f2a0c..2cfb88e 100644 --- a/pkg/interceptor/room/interceptor.go +++ b/pkg/interceptor/room/interceptor.go @@ -8,6 +8,7 @@ import ( "github.com/coder/websocket" "github.com/harshabose/skyline_sonata/serve/pkg/interceptor" + "github.com/harshabose/skyline_sonata/serve/pkg/message" ) type Interceptor struct { @@ -30,7 +31,7 @@ func (i *Interceptor) BindSocketConnection(connection interceptor.Connection, wr } func (i *Interceptor) InterceptSocketReader(reader interceptor.Reader) interceptor.Reader { - return interceptor.ReaderFunc(func(connection interceptor.Connection) (websocket.MessageType, interceptor.Message, error) { + return interceptor.ReaderFunc(func(connection interceptor.Connection) (websocket.MessageType, message.Message, error) { messageType, data, err := reader.Read(connection) if err != nil { return messageType, data, err @@ -44,15 +45,14 @@ func (i *Interceptor) InterceptSocketReader(reader interceptor.Reader) intercept i.Mutex.Lock() defer i.Mutex.Unlock() + payload, err := PayloadUnmarshal(msg.SubType, msg.Payload) + if err != nil { + return messageType, data, nil + } + if _, exists := i.states[connection]; exists { - payload, err := PayloadUnmarshal(msg.SubType, msg.Payload) - if err != nil { - fmt.Println("error while processing room message: ", err.Error()) - return messageType, data, nil - } if err := payload.Process(msg.Header, i, connection); err != nil { - fmt.Println("error while processing room message: ", err.Error()) return messageType, data, nil } } diff --git a/pkg/message/message.go b/pkg/message/message.go index cf69a8c..bab7150 100644 --- a/pkg/message/message.go +++ b/pkg/message/message.go @@ -1,6 +1,29 @@ package message +import "encoding/json" + +type Protocol string + type Message interface { Marshal() ([]byte, error) Unmarshal([]byte) error } + +type Header struct { + SenderID string `json:"source_id"` + ReceiverID string `json:"destination_id"` + Protocol Protocol `json:"protocol"` +} + +type BaseMessage struct { + Header + Payload json.RawMessage `json:"payload,omitempty"` +} + +func (msg *BaseMessage) Marshal() ([]byte, error) { + return json.Marshal(msg) +} + +func (msg *BaseMessage) Unmarshal(data []byte) error { + return json.Unmarshal(data, msg) +}