general commit

This commit is contained in:
harshabose
2025-04-04 12:49:36 +05:30
parent 553b0fe3b9
commit 4f32edab5e
27 changed files with 531 additions and 961 deletions
+103
View File
@@ -0,0 +1,103 @@
# WebSocket Interceptor Framework
A flexible, extensible WebSocket middleware framework for building secure, scalable real-time communication systems.
## Overview
This framework provides an interceptor pattern implementation for WebSocket connections, allowing you to add middleware-like functionality to your WebSocket applications. Rather than the traditional approach of building communication mechanisms first and adding middleware later, this framework inverts that pattern by starting with interceptors that can work with any underlying communication stack.
## Key Features
- **Middleware for WebSockets**: Add cross-cutting concerns like encryption, authentication, logging, and compression to WebSocket connections
- **Protocol-Based Message Routing**: Nest messages with protocol identifiers for flexible routing
- **Connection Lifecycle Management**: Proper binding and cleanup of resources
- **Transport Agnostic**: Works with any communication system that supports read/write operations
- **Composable Architecture**: Chain interceptors together to build complex functionality
- **Extensible Design**: Easy to add new interceptors without modifying existing code
## Use Cases
- **Secure Communications**: Add encryption layers to WebSocket traffic
- **Real-Time Monitoring**: Log and analyze message patterns
- **Protocol Translation**: Adapt between different message formats
- **Access Control**: Implement authentication and authorization
- **Rate Limiting**: Protect your system from excessive traffic
- **Message Transformation**: Compress, validate, or transform messages in transit
## Architecture
The framework is built around several key interfaces:
- **Interceptor**: The core interface that all interceptors implement
- **Connection**: Represents a WebSocket connection
- **Writer**: Handles outgoing messages
- **Reader**: Handles incoming messages
Interceptors can be chained together to form a processing pipeline for messages. Each interceptor can examine, modify, or route messages based on their protocol and content.
## Message Structure
Messages use a flexible nested structure:
```go
type BaseMessage struct {
Header
Payload json.RawMessage `json:"payload"`
}
type Header struct {
SenderID string `json:"source_id"`
ReceiverID string `json:"destination_id"`
Protocol Protocol `json:"protocol"`
}
```
The Protocol field identifies the type of message and determines how the Payload should be processed. Messages can be nested to arbitrary depth, with each layer having its own protocol identifier.
## Example: Encryption Interceptor
An encryption interceptor can seamlessly add security to your WebSocket communications:
```go
func (i *EncryptionInterceptor) InterceptSocketWriter(writer interceptor.Writer) interceptor.Writer {
return interceptor.WriterFunc(func(conn interceptor.Connection, messageType websocket.MessageType, m message.Message) error {
// Check if this connection has encryption enabled
state, exists := i.getState(conn)
if !exists {
return writer.Write(conn, messageType, m)
}
// Encrypt the message
encrypted, err := state.encryptor.Encrypt(m.Message().SenderID, m.Message().ReceiverID, m)
if err != nil {
return writer.Write(conn, messageType, m)
}
// Send the encrypted message
return writer.Write(conn, messageType, encrypted)
})
}
```
## Benefits Over Traditional Approaches
This "interceptors-first" approach offers several advantages:
1. **Separation of Concerns**: Clean separation between communication mechanics and business logic
2. **Framework Agnosticism**: Swap underlying communication technology without changing application code
3. **Easier Testing**: Test each interceptor in isolation
4. **Adaptability**: Add new functionality (encryption, logging, etc.) without modifying existing code
5. **Protocol Evolution**: Change protocols without widespread codebase changes
6. **Reduced Technical Debt**: Keep cross-cutting concerns localized to specific interceptors
## Getting Started
[Installation and basic usage instructions would go here]
## Example Usage
[Code examples showing how to set up and use interceptors would go here]
## License
[License information would go here]
+6
View File
@@ -1,5 +1,11 @@
package interceptor
import (
"github.com/coder/websocket"
"github.com/harshabose/skyline_sonata/serve/pkg/message"
)
// Chain implements the Interceptor interface by combining multiple interceptors
// into a sequential processing pipeline. Each interceptor in the chain gets a chance
// to process the connection, reader, and writer in the order they were added.
+18 -24
View File
@@ -13,53 +13,47 @@ import (
)
type encryptor interface {
Encrypt(message.Message) (*interceptor.BaseMessage, error)
Decrypt(*interceptor.BaseMessage) (message.Message, error)
SetKey(key []byte) error
Encrypt(string, string, message.Message) (*message.BaseMessage, error)
Decrypt(*Encrypted) (message.Message, error)
Close() error
}
type aes256 struct {
key []byte
encryptor cipher.AEAD
sessionID []byte
mux sync.RWMutex
}
func createAES256() (*aes256, error) {
key := make([]byte, 32)
if _, err := io.ReadFull(rand.Reader, key); err != nil {
return nil, err
}
func (a *aes256) SetKey(key []byte) error {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
return err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
return err
}
sessionID := make([]byte, 16)
if _, err := io.ReadFull(rand.Reader, sessionID); err != nil {
return nil, err
return err
}
return &aes256{
key: key,
encryptor: gcm,
sessionID: sessionID,
}, nil
a.encryptor = gcm
a.sessionID = sessionID
return nil
}
func (a *aes256) Encrypt(senderID, receiverID string, message message.Message) (*interceptor.BaseMessage, error) {
func (a *aes256) Encrypt(senderID, receiverID string, m message.Message) (*message.BaseMessage, error) {
nonce := make([]byte, 12)
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return nil, err
}
data, err := message.Marshal()
data, err := m.Marshal()
if err != nil {
return nil, err
}
@@ -68,19 +62,19 @@ func (a *aes256) Encrypt(senderID, receiverID string, message message.Message) (
defer a.mux.Unlock()
encryptedData := a.encryptor.Seal(nil, nonce, data, a.sessionID)
payload := &Encrypted{Data: encryptedData, Nonce: nonce, Timestamp: time.Now()}
return CreateMessage(senderID, receiverID, payload)
return message.CreateMessage(senderID, receiverID, NewEncrypt(senderID, receiverID, encryptedData, nonce))
}
func (a *aes256) Decrypt(message *Encrypted) (message.Message, error) {
_, err := a.encryptor.Open(nil, message.Nonce, message.Data, a.sessionID)
a.mux.Lock()
defer a.mux.Unlock()
data, err := a.encryptor.Open(nil, message.Nonce, message.Data, a.sessionID)
if err != nil {
return nil, err
}
// TODO: figure out how to create a message here
return nil, nil
}
+79 -125
View File
@@ -2,14 +2,9 @@ package encrypt
import (
"context"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"io"
"sync"
"time"
"github.com/coder/websocket"
@@ -24,162 +19,121 @@ type Interceptor struct {
ctx context.Context
}
func (encrypt *Interceptor) BindSocketConnection(connection *websocket.Conn) error {
encrypt.mux.Lock()
defer encrypt.mux.Unlock()
func (i *Interceptor) BindSocketConnection(connection interceptor.Connection, writer interceptor.Writer, reader interceptor.Reader) error {
i.mux.Lock()
defer i.mux.Unlock()
_, exists := i.states[connection]
if exists {
return errors.New("connection already exists")
}
ctx, cancel := context.WithCancel(i.Ctx)
i.states[connection] = &state{
id: "unknown",
encryptor: &aes256{},
writer: writer,
reader: reader,
cancel: cancel,
ctx: ctx,
}
// TODO: Exchange keys with the peer using a key exchange protocol like Diffie-Hellman
// TODO: Store different keys for encryption and decryption
go encrypt.loop(connection)
return nil
}
func (encrypt *Interceptor) BindSocketWriter(writer interceptor.Writer) interceptor.Writer {
return interceptor.WriterFunc(func(connection interceptor.Connection, messageType websocket.MessageType, message message.Message) error {
state, exists := encrypt.states[connection]
func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) interceptor.Writer {
/*
Takes in any type of message.Message and encrypts it. In general, all implementations of
message.Message should use message.BaseMessage to implement message.Message.
*/
return interceptor.WriterFunc(func(connection interceptor.Connection, messageType websocket.MessageType, m message.Message) error {
i.mux.Lock()
defer i.mux.Unlock()
state, exists := i.states[connection]
if !exists {
return
}
})
return interceptor.WriterFunc(func(conn *websocket.Conn, messageType websocket.MessageType, data []byte) error {
encrypt.mux.Lock()
collection, exists := encrypt.states[conn]
encrypt.mux.Unlock()
if !exists || collection.encryptor == nil {
// No encryption configured for this connection yet
// Pass through unencrypted
return writer.Write(conn, messageType, data)
return writer.Write(connection, messageType, m)
}
// Generate a nonce for this message
nonce := make([]byte, 12) // GCM typically uses a 12-byte nonce
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return fmt.Errorf("failed to generate nonce: %w", err)
msg, err := state.encryptor.Encrypt(m.Message().SenderID, m.Message().ReceiverID, m)
if err != nil {
return writer.Write(connection, messageType, m)
}
// Encrypt the data
encryptor := collection.encryptor
sessionID := collection.sessionID
encryptedData := encryptor.Seal(nil, nonce, data, sessionID)
// Format the encrypted message:
// [2-byte nonce length][nonce][encrypted data]
finalData := make([]byte, 2+len(nonce)+len(encryptedData))
binary.BigEndian.PutUint16(finalData[:2], uint16(len(nonce)))
copy(finalData[2:], nonce)
copy(finalData[2+len(nonce):], encryptedData)
// Send the encrypted message
return writer.Write(conn, messageType, finalData)
return writer.Write(connection, messageType, msg)
})
}
func (encrypt *Interceptor) BindSocketReader(reader interceptor.Reader) interceptor.Reader {
return interceptor.ReaderFunc(func(conn *websocket.Conn) (websocket.MessageType, []byte, error) {
// Read the encrypted message
messageType, encryptedData, err := reader.Read(conn)
func (i *Interceptor) InterceptSocketReader(reader interceptor.Reader) interceptor.Reader {
return interceptor.ReaderFunc(func(connection interceptor.Connection) (messageType websocket.MessageType, message message.Message, err error) {
i.mux.Lock()
defer i.mux.Unlock()
messageType, message, err = reader.Read(connection)
if err != nil {
return messageType, encryptedData, err
return messageType, message, err
}
encrypt.mux.Lock()
collection, exists := encrypt.states[conn]
encrypt.mux.Unlock()
if !exists || collection.encryptor == nil || len(encryptedData) < 2 {
// No decryption configured or data too short to be encrypted
// Pass through as-is
return messageType, encryptedData, nil
msg, ok := message.(*interceptor.BaseMessage)
if !ok || (msg.Protocol != interceptor.IProtocol && msg.MainType != MainType) {
return messageType, message, nil
}
// Extract nonce length
nonceLen := binary.BigEndian.Uint16(encryptedData[:2])
// Ensure we have enough data for the nonce and at least some ciphertext
if len(encryptedData) < int(2+nonceLen) {
return messageType, encryptedData, fmt.Errorf("encrypted data too short")
}
// Extract nonce and ciphertext
nonce := encryptedData[2 : 2+nonceLen]
ciphertext := encryptedData[2+nonceLen:]
// Decrypt the data
decryptor := collection.encryptor
sessionID := collection.sessionID
plaintext, err := decryptor.Open(nil, nonce, ciphertext, sessionID)
payload, err := PayloadUnmarshal(msg.SubType, msg.Payload)
if err != nil {
return messageType, encryptedData, fmt.Errorf("decryption failed: %w", err)
return messageType, message, err
}
return messageType, plaintext, nil
state, exists := i.states[connection]
if !exists {
return messageType, message, nil
}
if err := payload.Process(msg.BaseMessage.Header, i, connection); err != nil {
fmt.Println("error while processing encryptor message:", err.Error())
}
p, ok := payload.(*Encrypted)
if !ok {
return messageType, message, nil
}
message, err = state.encryptor.Decrypt(p)
if err != nil {
return messageType, message, nil
}
return messageType, message, nil
})
}
func (encrypt *Interceptor) UnBindSocketConnection(connection *websocket.Conn) {
encrypt.mux.Lock()
defer encrypt.mux.Unlock()
func (i *Interceptor) UnBindSocketConnection(connection interceptor.Connection) {
i.mux.Lock()
defer i.mux.Unlock()
collection, exists := encrypt.states[connection]
state, exists := i.states[connection]
if !exists {
fmt.Println("connection does not exists")
return
}
collection.cancel()
delete(encrypt.states, connection)
state.cancel()
delete(i.states, connection)
}
func (encrypt *Interceptor) UnBindSocketWriter(_ interceptor.Writer) {}
func (i *Interceptor) UnInterceptSocketWriter(_ interceptor.Writer) {}
func (encrypt *Interceptor) UnBindSocketReader(_ interceptor.Reader) {}
func (i *Interceptor) UnInterceptSocketReader(_ interceptor.Reader) {}
func (encrypt *Interceptor) Close() error {
encrypt.mux.Lock()
defer encrypt.mux.Unlock()
func (i *Interceptor) Close() error {
i.mux.Lock()
defer i.mux.Unlock()
encrypt.states = make(map[*websocket.Conn]*collection)
i.states = make(map[interceptor.Connection]*state)
return nil
}
func (encrypt *Interceptor) loop(connection *websocket.Conn) {
encrypt.mux.Lock()
collection, exists := encrypt.states[connection]
if !exists {
fmt.Println("connection does not exists")
return
}
ctx := collection.ctx
encrypt.mux.Unlock()
timer := time.NewTicker(5 * time.Minute)
defer timer.Stop()
for {
select {
case <-timer.C:
encrypt.mux.Lock()
newSessionID := make([]byte, 16)
if _, err := io.ReadFull(rand.Reader, newSessionID); err != nil {
fmt.Println("Failed to generate new session messageID:", err)
continue
}
if collection, exists := encrypt.states[connection]; exists {
collection.sessionID = nil // Keep nil until sending to peer mechanism is set
}
// send the update sessionID to peer
encrypt.mux.Unlock()
case <-ctx.Done():
return
}
}
}
+46 -30
View File
@@ -13,43 +13,46 @@ var (
MainType interceptor.MainType = "encrypt"
EncryptedSubType interceptor.SubType = "encrypted"
subTypeMap = map[interceptor.SubType]interceptor.Payload{
EncryptedSubType: &Encrypted{},
}
)
func CreateMessage(senderID, receiverID string, payload interceptor.Payload) (*interceptor.BaseMessage, error) {
data, err := payload.Marshal()
if err != nil {
return nil, err
func PayloadUnmarshal(sub interceptor.SubType, p json.RawMessage) (interceptor.Payload, error) {
if payload, exists := subTypeMap[sub]; exists {
if err := payload.Unmarshal(p); err != nil {
return nil, err
}
return payload, nil
}
return &interceptor.BaseMessage{
BaseMessage: message.BaseMessage{
Header: message.Header{
SenderID: senderID,
ReceiverID: receiverID,
Protocol: interceptor.IProtocol,
},
Payload: data,
},
Header: interceptor.Header{
MainType: MainType,
SubType: payload.Type(),
},
}, nil
return nil, errors.New("processor does not exist for given type")
}
type Encrypted struct {
message.BaseMessage
Data []byte `json:"data"`
Nonce []byte `json:"nonce"`
Timestamp time.Time `json:"timestamp"`
}
func (payload *Encrypted) Marshal() ([]byte, error) {
return json.Marshal(payload)
}
var Protocol message.Protocol = "encrypt"
func (payload *Encrypted) Unmarshal(data []byte) error {
return json.Unmarshal(data, payload)
func NewEncrypt(senderID, receiverID string, data, nonce []byte) *Encrypted {
return &Encrypted{
BaseMessage: message.BaseMessage{
Header: message.Header{
SenderID: senderID,
ReceiverID: receiverID,
Protocol: message.NoneProtocol,
},
Payload: nil,
},
Data: data,
Nonce: nonce,
Timestamp: time.Now(),
}
}
func (payload *Encrypted) Validate() error {
@@ -57,14 +60,27 @@ func (payload *Encrypted) Validate() error {
return errors.New("not valid")
}
return nil
return payload.BaseMessage.Validate()
}
func (payload *Encrypted) Process(header interceptor.Header, i interceptor.Interceptor, connection interceptor.Connection) error {
// TODO implement me
panic("implement me")
func (payload *Encrypted) Process(_interceptor interceptor.Interceptor, connection interceptor.Connection) error {
i, ok := _interceptor.(*Interceptor)
if !ok {
return errors.New("inappropriate interceptor for the payload")
}
state, exists := i.states[connection]
if !exists {
return errors.New("connection not registered")
}
msg, err := state.encryptor.Decrypt(payload)
if err != nil {
return err
}
}
func (payload *Encrypted) Type() interceptor.SubType {
return EncryptedSubType
func (payload *Encrypted) Protocol() message.Protocol {
return Protocol
}
+11 -1
View File
@@ -1,10 +1,20 @@
package encrypt
import "context"
import (
"context"
"github.com/harshabose/skyline_sonata/serve/pkg/interceptor"
)
type stats struct {
}
type state struct {
stats
id string
encryptor encryptor
writer interceptor.Writer
reader interceptor.Reader
cancel context.CancelFunc
ctx context.Context
}
+19 -19
View File
@@ -189,23 +189,23 @@ func (interceptor *NoOpInterceptor) Close() error {
return nil
}
// Payload defines the interface for protocol message contents.
// It extends the base message.Message interface with validation and processing
// capabilities specific to the protocol. Each implementation represents
// a different message type within the protocol.
// // Payload defines the interface for protocol message contents.
// // It extends the base message.Message interface with validation and processing
// // capabilities specific to the protocol. Each implementation represents
// // a different message type within the protocol.
// //
// // Implementations must be able to validate their own content and process
// // themselves against their respective Interceptor when received.
// type Payload interface {
// Marshal() ([]byte, error)
//
// Implementations must be able to validate their own content and process
// themselves against their respective Interceptor when received.
type Payload interface {
Marshal() ([]byte, error)
Unmarshal([]byte) error
// Validate checks if the payload data is well-formed and valid
// according to the protocol requirements.
Validate() error
// Process handles the payload-specific logic when a message is received,
// updating the appropriate state in the manager for the given connection.
Process(Header, Interceptor, Connection) error
Type() SubType
}
// Unmarshal([]byte) error
// // Validate checks if the payload data is well-formed and valid
// // according to the protocol requirements.
// Validate() error
// // Process handles the payload-specific logic when a message is received,
// // updating the appropriate state in the manager for the given connection.
// Process(message.Header, Interceptor, Connection) error
//
// Type() SubType
// }
-22
View File
@@ -1,22 +0,0 @@
package interceptor
import (
"github.com/harshabose/skyline_sonata/serve/pkg/message"
)
type (
SubType string
MainType string
)
type Header struct {
MainType MainType `json:"main_type"`
SubType SubType `json:"sub_type"`
}
var IProtocol message.Protocol = "interceptor"
type BaseMessage struct {
Header
message.BaseMessage
}
-1
View File
@@ -1 +0,0 @@
package ping
-139
View File
@@ -1,139 +0,0 @@
package ping
import (
"encoding/json"
"errors"
"time"
"github.com/harshabose/skyline_sonata/serve/pkg/interceptor"
)
var (
MainType interceptor.MainType = "ping-pong"
PingSubType interceptor.SubType = "ping"
PongSubType interceptor.SubType = "pong"
subTypeMap = map[interceptor.SubType]interceptor.Payload{
PingSubType: &Ping{},
PongSubType: &Pong{},
}
)
func PayloadUnmarshal(sub interceptor.SubType, p json.RawMessage) (interceptor.Payload, error) {
if payload, exists := subTypeMap[sub]; exists {
if err := payload.Unmarshal(p); err != nil {
return nil, err
}
return payload, nil
}
return nil, errors.New("processor does not exist for given type")
}
func CreateMessage(senderID, receiverID string, payload interceptor.Payload) (*interceptor.BaseMessage, error) {
data, err := payload.Marshal()
if err != nil {
return nil, err
}
return &interceptor.BaseMessage{
Header: interceptor.Header{
SenderID: senderID,
ReceiverID: receiverID,
MainType: MainType,
SubType: payload.Type(),
},
Payload: data,
}, nil
}
// Ping represents a connection health check message sent by the server.
// Each ping contains a unique message ID and a timestamp that can be used
// to measure round-trip time when a corresponding pong is received.
type Ping struct {
MessageID string `json:"message_id"` // Unique identifier for matching with pong
Timestamp time.Time `json:"timestamp"` // When the ping was sent
}
// Marshal serializes the ping payload into a JSON byte array.
// This is typically used when the ping is embedded in a Message.
//
// Returns:
// - The JSON-encoded ping as a byte array
// - Any error encountered during serialization
func (payload *Ping) Marshal() ([]byte, error) {
return json.Marshal(payload)
}
// Unmarshal deserializes a JSON byte array into this ping structure.
// This processes ping data received from a websocket message.
//
// Parameters:
// - data: The JSON-encoded ping as a byte array
//
// Returns:
// - Any error encountered during deserialization
func (payload *Ping) Unmarshal(data []byte) error {
return json.Unmarshal(data, payload)
}
// Validate checks if the ping message contains valid data.
// Currently, this is a placeholder for future validation logic.
// Future implementations could validate the message ID format and
// ensure the timestamp is within an acceptable range.
//
// Returns:
// - An error if validation fails, nil otherwise
func (payload *Ping) Validate() error {
return nil
}
func (payload *Ping) Type() interceptor.SubType {
return PingSubType
}
// Pong represents a response to a ping message, confirming connection health.
// It contains the original ping's message ID and timestamp, plus its own timestamp,
// allowing the server to calculate the round-trip time.
type Pong struct {
MessageID string `json:"message_id"` // Matches the corresponding ping's ID
PingTimestamp time.Time `json:"ping_timestamp"` // When the original ping was sent
Timestamp time.Time `json:"timestamp"` // When this pong was generated
}
// Marshal serializes the pong payload into a JSON byte array.
// This is typically used when the pong is embedded in a Message.
//
// Returns:
// - The JSON-encoded pong as a byte array
// - Any error encountered during serialization
func (payload *Pong) Marshal() ([]byte, error) {
return json.Marshal(payload)
}
// Unmarshal deserializes a JSON byte array into this pong structure.
// This processes pong data received from a websocket message.
//
// Parameters:
// - data: The JSON-encoded pong as a byte array
//
// Returns:
// - Any error encountered during deserialization
func (payload *Pong) Unmarshal(data []byte) error {
return json.Unmarshal(data, payload)
}
// Validate checks if the pong message contains valid data.
// Currently, this is a placeholder for future validation logic.
// Future implementations could validate the message ID format and
// ensure the timestamps are within acceptable ranges.
//
// Returns:
// - An error if validation fails, nil otherwise
func (payload *Pong) Validate() error {
return nil
}
func (payload *Pong) Type() interceptor.SubType {
return PongSubType
}
@@ -1,4 +1,4 @@
package ping
package pingpong
import (
"context"
@@ -21,8 +21,9 @@ type InterceptorFactory struct {
// WithInterval creates an option that sets the ping message interval.
// This controls how frequently the interceptor sends ping messages to
// connected clients to verify connection health.
//
// connected clients to verify connection health. This starts a constant
// ping loop for new connection; thus use only when this interceptor
// needs to send pings.
// Parameters:
// - interval: Time duration between ping messages
//
@@ -31,6 +32,7 @@ type InterceptorFactory struct {
func WithInterval(interval time.Duration) Option {
return func(interceptor *Interceptor) error {
interceptor.interval = interval
interceptor.ping = true
return nil
}
}
@@ -84,7 +86,9 @@ func (factory *InterceptorFactory) NewInterceptor(ctx context.Context, id string
ID: id,
Ctx: ctx,
},
states: make(map[interceptor.Connection]*state),
states: make(map[interceptor.Connection]*state),
interval: time.Duration(0),
ping: false,
}
for _, option := range factory.opts {
@@ -1,4 +1,4 @@
package ping
package pingpong
import (
"context"
@@ -1,4 +1,4 @@
package ping
package pingpong
import (
"context"
@@ -7,7 +7,6 @@ import (
"time"
"github.com/coder/websocket"
"github.com/google/uuid"
"github.com/harshabose/skyline_sonata/serve/pkg/interceptor"
"github.com/harshabose/skyline_sonata/serve/pkg/message"
@@ -18,6 +17,7 @@ type Interceptor struct {
states map[interceptor.Connection]*state
maxHistory uint16
interval time.Duration // Time between ping messages
ping bool
}
func (i *Interceptor) BindSocketConnection(connection interceptor.Connection, writer interceptor.Writer, reader interceptor.Reader) error {
@@ -26,7 +26,7 @@ func (i *Interceptor) BindSocketConnection(connection interceptor.Connection, wr
_, exists := i.states[connection]
if exists {
return errors.New("owner already exists")
return errors.New("connection already exists")
}
ctx, cancel := context.WithCancel(i.Ctx)
@@ -42,63 +42,59 @@ func (i *Interceptor) BindSocketConnection(connection interceptor.Connection, wr
cancel: cancel,
}
go i.loop(ctx, i.interval, connection)
if i.ping {
go i.loop(ctx, i.interval, connection)
}
return nil
}
func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) interceptor.Writer {
return interceptor.WriterFunc(func(conn interceptor.Connection, messageType websocket.MessageType, message message.Message) error {
return interceptor.WriterFunc(func(conn interceptor.Connection, messageType websocket.MessageType, m message.Message) error {
i.Mutex.Lock()
defer i.Mutex.Unlock()
msg, ok := message.(*interceptor.BaseMessage)
if !ok || (msg.Protocol != interceptor.IProtocol && msg.MainType != MainType) {
return writer.Write(conn, messageType, message)
if _, exists := i.states[conn]; !exists {
return writer.Write(conn, messageType, m)
}
payload, err := PayloadUnmarshal(msg.SubType, msg.Payload)
payload, err := ProtocolUnmarshal(m.Message().Header.Protocol, m.Message().Payload)
if err != nil {
return writer.Write(conn, messageType, message)
return writer.Write(conn, messageType, m)
}
if _, exists := i.states[conn]; exists {
if err := payload.Process(msg.Header, i, conn); err != nil {
fmt.Println("error while processing ping pong message: ", err.Error())
}
if err := payload.Process(i, conn); err != nil {
return writer.Write(conn, messageType, m)
}
return writer.Write(conn, messageType, message)
return writer.Write(conn, messageType, m)
})
}
func (i *Interceptor) InterceptSocketReader(reader interceptor.Reader) interceptor.Reader {
return interceptor.ReaderFunc(func(conn interceptor.Connection) (messageType websocket.MessageType, message message.Message, err error) {
messageType, message, err = reader.Read(conn)
return interceptor.ReaderFunc(func(conn interceptor.Connection) (messageType websocket.MessageType, m message.Message, err error) {
messageType, m, err = reader.Read(conn)
if err != nil {
return messageType, message, err
return messageType, m, err
}
i.Mutex.Lock()
defer i.Mutex.Unlock()
msg, ok := message.(*interceptor.BaseMessage)
if !ok || (msg.Protocol != interceptor.IProtocol && msg.MainType != MainType) {
return messageType, message, nil
if _, exists := i.states[conn]; !exists {
return messageType, m, nil
}
payload, err := PayloadUnmarshal(msg.SubType, msg.Payload)
payload, err := ProtocolUnmarshal(m.Message().Header.Protocol, m.Message().Payload)
if err != nil {
return messageType, message, err
return messageType, m, nil
}
if _, exists := i.states[conn]; exists {
if err := payload.Process(msg.Header, i, conn); err != nil {
fmt.Println("error while processing ping pong message: ", err.Error())
}
if err := payload.Process(i, conn); err != nil {
return messageType, m, nil
}
return messageType, message, nil
return messageType, m, nil
})
}
@@ -149,9 +145,9 @@ func (i *Interceptor) loop(ctx context.Context, interval time.Duration, connecti
continue
}
msg, err := CreateMessage(i.ID, state.peerid, &Ping{MessageID: uuid.NewString(), Timestamp: time.Now()})
msg, err := message.CreateMessage(i.ID, state.peerid, NewPing(i.ID, state.peerid))
if err != nil {
fmt.Println("error while trying to send ping:", err.Error())
continue
}
if err := state.writer.Write(connection, websocket.MessageText, msg); err != nil {
@@ -162,7 +158,7 @@ func (i *Interceptor) loop(ctx context.Context, interval time.Duration, connecti
}
}
func (payload *Ping) Process(_ interceptor.Header, interceptor interceptor.Interceptor, connection interceptor.Connection) error {
func (payload *Ping) Process(interceptor interceptor.Interceptor, connection interceptor.Connection) error {
if err := payload.Validate(); err != nil {
return err
}
@@ -176,13 +172,22 @@ func (payload *Ping) Process(_ interceptor.Header, interceptor interceptor.Inter
if !exists {
return errors.New("connection does not exists")
}
state.peerid = payload.SenderID
state.recordPing(payload)
if !i.ping {
msg, err := message.CreateMessage(i.ID, state.peerid, NewPong(i.ID, payload))
if err != nil {
return err
}
return state.writer.Write(connection, websocket.MessageText, msg)
}
return nil
}
func (payload *Pong) Process(header interceptor.Header, interceptor interceptor.Interceptor, connection interceptor.Connection) error {
func (payload *Pong) Process(interceptor interceptor.Interceptor, connection interceptor.Connection) error {
if err := payload.Validate(); err != nil {
return err
}
@@ -197,7 +202,7 @@ func (payload *Pong) Process(header interceptor.Header, interceptor interceptor.
return errors.New("connection does not exists")
}
state.peerid = header.SenderID
state.peerid = payload.SenderID
state.recordPong(payload)
return nil
@@ -0,0 +1 @@
package pingpong
+114
View File
@@ -0,0 +1,114 @@
package pingpong
import (
"encoding/json"
"errors"
"time"
"github.com/google/uuid"
"github.com/harshabose/skyline_sonata/serve/pkg/message"
)
var (
protocolMap = map[message.Protocol]message.Message{
ProtocolPing: &Ping{},
ProtocolPong: &Pong{},
}
)
func ProtocolUnmarshal(protocol message.Protocol, data json.RawMessage) (message.Message, error) {
msg, exists := protocolMap[protocol]
if !exists {
return nil, errors.New("protocol no match")
}
if err := msg.Unmarshal(data); err != nil {
return nil, err
}
return msg, nil
}
// Ping represents a connection health check message sent by the server.
// Each ping contains a unique message ID and a timestamp that can be used
// to measure round-trip time when a corresponding pong is received.
type Ping struct {
message.BaseMessage // NOTE: EMPTY PAYLOAD
MessageID string `json:"message_id"` // Unique identifier for matching with pong
Timestamp time.Time `json:"timestamp"` // When the ping was sent
}
var ProtocolPing message.Protocol = "ping"
func NewPing(senderID, receiverID string) *Ping {
return &Ping{
BaseMessage: message.BaseMessage{
Header: message.Header{
SenderID: senderID,
ReceiverID: receiverID,
Protocol: message.NoneProtocol,
},
Payload: nil,
},
MessageID: uuid.NewString(),
Timestamp: time.Now(),
}
}
// Validate checks if the ping message contains valid data.
// Currently, this is a placeholder for future validation logic.
// Future implementations could validate the message ID format and
// ensure the timestamp is within an acceptable range.
//
// Returns:
// - An error if validation fails, nil otherwise
func (payload *Ping) Validate() error {
if payload.MessageID == "" {
return message.ErrorNotValid
}
return payload.BaseMessage.Validate()
}
func (payload *Ping) Protocol() message.Protocol {
return ProtocolPing
}
// Pong represents a response to a ping message, confirming connection health.
// It contains the original ping's message ID and timestamp, plus its own timestamp,
// allowing the server to calculate the round-trip time.
type Pong struct {
message.BaseMessage // NOTE: EMPTY PAYLOAD
MessageID string `json:"message_id"` // Unique identifier for matching with pong
Timestamp time.Time `json:"timestamp"` // When the ping was sent
PingTimestamp time.Time `json:"ping_timestamp"` // When the original ping was sent
}
func NewPong(senderID string, ping *Ping) *Pong {
return &Pong{
BaseMessage: message.BaseMessage{
Header: message.Header{
SenderID: senderID,
ReceiverID: ping.SenderID,
Protocol: message.NoneProtocol,
},
Payload: nil,
},
MessageID: ping.MessageID,
Timestamp: time.Now(),
PingTimestamp: ping.Timestamp,
}
}
var ProtocolPong message.Protocol = "pong"
func (payload *Pong) Protocol() message.Protocol {
return ProtocolPong
}
func (payload *Pong) Validate() error {
if payload.MessageID == "" {
return message.ErrorNotValid
}
return payload.BaseMessage.Validate()
}
@@ -1,4 +1,4 @@
package ping
package pingpong
import (
_ "bytes"
@@ -1,4 +1,4 @@
package ping
package pingpong
import (
"context"
-80
View File
@@ -1,80 +0,0 @@
package pong
import (
"context"
"github.com/harshabose/skyline_sonata/serve/pkg/interceptor"
)
// Option defines a function type that configures an Interceptor instance.
// Each option modifies a specific aspect of the interceptor's behavior
// and returns an error if the configuration cannot be applied.
type Option = func(*Interceptor) error
// InterceptorFactory creates ping interceptors with a predefined set of options.
// It implements the interceptor.Factory interface, allowing it to be registered
// with the interceptor registry for automatic interceptor creation.
type InterceptorFactory struct {
opts []Option // Collection of configuration options to apply
}
// WithMaxHistory creates an option that sets the maximum number of ping/pong
// records to keep in history. This limits memory usage while still allowing
// for statistical analysis of connection performance.
//
// Parameters:
// - max: Maximum number of historical ping/pong records to maintain
//
// Returns:
// - An Option that configures history limit when applied to an interceptor
func WithMaxHistory(max uint16) Option {
return func(interceptor *Interceptor) error {
interceptor.maxHistory = max
return nil
}
}
// CreateInterceptorFactory constructs a new factory that will create ping interceptors
// with the provided options. The options are stored and applied to each new
// interceptor created by the factory.
//
// Parameters:
// - options: Variable number of options to configure created interceptors
//
// Returns:
// - A configured InterceptorFactory that will create ping interceptors
func CreateInterceptorFactory(options ...Option) *InterceptorFactory {
return &InterceptorFactory{
opts: options,
}
}
// NewInterceptor creates and configures a new ping interceptor instance.
// It initializes the base NoOpInterceptor structure, creates a ping manager,
// and applies all stored options to customize the interceptor's behavior.
// This method implements the interceptor.Factory interface.
//
// Parameters:
// - ctx: Context that controls the lifetime of the interceptor
// - id: Unique identifier for the interceptor
//
// Returns:
// - A configured ping interceptor
// - Any error encountered during interceptor creation or configuration
func (factory *InterceptorFactory) NewInterceptor(ctx context.Context, id string) (interceptor.Interceptor, error) {
pongInterceptor := &Interceptor{
NoOpInterceptor: interceptor.NoOpInterceptor{
ID: id,
Ctx: ctx,
},
states: make(map[interceptor.Connection]*state),
}
for _, option := range factory.opts {
if err := option(pongInterceptor); err != nil {
return nil, err
}
}
return pongInterceptor, nil
}
-183
View File
@@ -1,183 +0,0 @@
package pong
import (
"context"
"errors"
"fmt"
"time"
"github.com/coder/websocket"
"github.com/harshabose/skyline_sonata/serve/pkg/interceptor"
"github.com/harshabose/skyline_sonata/serve/pkg/message"
)
type Interceptor struct {
interceptor.NoOpInterceptor
states map[interceptor.Connection]*state
maxHistory uint16
}
func (i *Interceptor) BindSocketConnection(connection interceptor.Connection, writer interceptor.Writer, reader interceptor.Reader) error {
i.Mutex.Lock()
defer i.Mutex.Unlock()
_, exists := i.states[connection]
if exists {
return errors.New("owner already exists")
}
ctx, cancel := context.WithCancel(i.Ctx)
i.states[connection] = &state{
peerid: "unknown", // unknown until first ping
writer: writer, // full-stack writer (this is different from the writer in InterceptSocketWriter)
reader: reader,
pings: make([]*ping, 0),
pongs: make([]*pong, 0),
max: i.maxHistory,
ctx: ctx,
cancel: cancel,
}
return nil
}
func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) interceptor.Writer {
return interceptor.WriterFunc(func(conn interceptor.Connection, messageType websocket.MessageType, message message.Message) error {
i.Mutex.Lock()
defer i.Mutex.Unlock()
msg, ok := message.(*interceptor.BaseMessage)
if !ok || (msg.Protocol != interceptor.IProtocol && msg.Header.MainType != MainType) {
return writer.Write(conn, messageType, message)
}
payload, err := PayloadUnmarshal(msg.SubType, msg.Payload)
if err != nil {
return writer.Write(conn, messageType, message)
}
if _, exists := i.states[conn]; exists {
if err := payload.Process(msg.Header, i, conn); err != nil {
fmt.Println("error while processing ping pong message: ", err.Error())
}
}
return writer.Write(conn, messageType, message)
})
}
func (i *Interceptor) InterceptSocketReader(reader interceptor.Reader) interceptor.Reader {
return interceptor.ReaderFunc(func(conn interceptor.Connection) (messageType websocket.MessageType, message message.Message, err error) {
messageType, message, err = reader.Read(conn)
if err != nil {
return messageType, message, err
}
i.Mutex.Lock()
defer i.Mutex.Unlock()
msg, ok := message.(*interceptor.BaseMessage)
if !ok || (msg.Protocol != interceptor.IProtocol && msg.Header.MainType != MainType) {
return messageType, message, nil
}
payload, err := PayloadUnmarshal(msg.SubType, msg.Payload)
if err != nil {
return messageType, message, err
}
if _, exists := i.states[conn]; exists {
if err := payload.Process(msg.Header, i, conn); err != nil {
fmt.Println("error while processing ping pong message: ", err.Error())
}
}
return messageType, message, nil
})
}
func (i *Interceptor) UnBindSocketConnection(connection interceptor.Connection) {
i.Mutex.Lock()
defer i.Mutex.Unlock()
i.states[connection].cancel()
delete(i.states, connection)
}
func (i *Interceptor) UnInterceptSocketWriter(_ interceptor.Writer) {
// If left unimplemented, NoOpInterceptor's default implementation will be used
// But, for reference, this method is implemented
}
func (i *Interceptor) UnInterceptSocketReader(_ interceptor.Reader) {
// If left unimplemented, NoOpInterceptor's default implementation will be used
// But, for reference, this method is implemented
}
func (i *Interceptor) Close() error {
i.Mutex.Lock()
defer i.Mutex.Unlock()
for _, state := range i.states {
state.cancel()
state.reader = nil
state.writer = nil
}
i.states = make(map[interceptor.Connection]*state)
return nil
}
func (payload *Ping) Process(header interceptor.Header, interceptor interceptor.Interceptor, connection interceptor.Connection) error {
if err := payload.Validate(); err != nil {
return err
}
i, ok := interceptor.(*Interceptor)
if !ok {
return errors.New("not appropriate interceptor to process this message")
}
i.Mutex.Lock()
defer i.Mutex.Unlock()
state, exists := i.states[connection]
if !exists {
return errors.New("connection does not exists")
}
state.peerid = header.SenderID
state.recordPing(payload)
msg, err := CreateMessage(i.ID, state.peerid, &Pong{MessageID: payload.MessageID, PingTimestamp: payload.Timestamp, Timestamp: time.Now()})
if err != nil {
return err
}
return state.writer.Write(connection, websocket.MessageText, msg)
}
func (payload *Pong) Process(_ interceptor.Header, interceptor interceptor.Interceptor, connection interceptor.Connection) error {
if err := payload.Validate(); err != nil {
return err
}
i, ok := interceptor.(*Interceptor)
if !ok {
return errors.New("not appropriate interceptor to process this message")
}
i.Mutex.Lock()
defer i.Mutex.Unlock()
state, exists := i.states[connection]
if !exists {
return errors.New("connection does not exists")
}
state.recordPong(payload)
return nil
}
-141
View File
@@ -1,141 +0,0 @@
package pong
import (
"encoding/json"
"errors"
"time"
"github.com/harshabose/skyline_sonata/serve/pkg/interceptor"
)
var (
MainType interceptor.MainType = "ping-pong"
PingSubType interceptor.SubType = "ping"
PongSubType interceptor.SubType = "pong"
subTypeMap = map[interceptor.SubType]interceptor.Payload{
PingSubType: &Ping{},
PongSubType: &Pong{},
}
)
func PayloadUnmarshal(sub interceptor.SubType, p json.RawMessage) (interceptor.Payload, error) {
if payload, exists := subTypeMap[sub]; exists {
if err := payload.Unmarshal(p); err != nil {
return nil, err
}
return payload, nil
}
return nil, errors.New("processor does not exist for given type")
}
func CreateMessage(senderID, receiverID string, payload interceptor.Payload) (*interceptor.BaseMessage, error) {
data, err := payload.Marshal()
if err != nil {
return nil, err
}
return &interceptor.BaseMessage{
Header: interceptor.Header{
SenderID: senderID,
ReceiverID: receiverID,
Protocol: interceptor.IProtocol,
MainType: MainType,
SubType: payload.Type(),
},
Payload: data,
}, nil
}
// Ping represents a connection health check message sent by the server.
// Each ping contains a unique message ID and a timestamp that can be used
// to measure round-trip time when a corresponding pong is received.
type Ping struct {
MessageID string `json:"message_id"` // Unique identifier for matching with pong
Timestamp time.Time `json:"timestamp"` // When the ping was sent
}
// Marshal serializes the ping payload into a JSON byte array.
// This is typically used when the ping is embedded in a Message.
//
// Returns:
// - The JSON-encoded ping as a byte array
// - Any error encountered during serialization
func (payload *Ping) Marshal() ([]byte, error) {
return json.Marshal(payload)
}
// Unmarshal deserializes a JSON byte array into this ping structure.
// This processes ping data received from a websocket message.
//
// Parameters:
// - data: The JSON-encoded ping as a byte array
//
// Returns:
// - Any error encountered during deserialization
func (payload *Ping) Unmarshal(data []byte) error {
return json.Unmarshal(data, payload)
}
// Validate checks if the ping message contains valid data.
// Currently, this is a placeholder for future validation logic.
// Future implementations could validate the message ID format and
// ensure the timestamp is within an acceptable range.
//
// Returns:
// - An error if validation fails, nil otherwise
func (payload *Ping) Validate() error {
return nil
}
func (payload *Ping) Type() interceptor.SubType {
return "ping"
}
// Pong represents a response to a ping message, confirming connection health.
// It contains the original ping's message ID and timestamp, plus its own timestamp,
// allowing the server to calculate the round-trip time.
type Pong struct {
MessageID string `json:"message_id"` // Matches the corresponding ping's ID
PingTimestamp time.Time `json:"ping_timestamp"` // When the original ping was sent
Timestamp time.Time `json:"timestamp"` // When this pong was generated
}
// Marshal serializes the pong payload into a JSON byte array.
// This is typically used when the pong is embedded in a Message.
//
// Returns:
// - The JSON-encoded pong as a byte array
// - Any error encountered during serialization
func (payload *Pong) Marshal() ([]byte, error) {
return json.Marshal(payload)
}
// Unmarshal deserializes a JSON byte array into this pong structure.
// This processes pong data received from a websocket message.
//
// Parameters:
// - data: The JSON-encoded pong as a byte array
//
// Returns:
// - Any error encountered during deserialization
func (payload *Pong) Unmarshal(data []byte) error {
return json.Unmarshal(data, payload)
}
// Validate checks if the pong message contains valid data.
// Currently, this is a placeholder for future validation logic.
// Future implementations could validate the message ID format and
// ensure the timestamps are within acceptable ranges.
//
// Returns:
// - An error if validation fails, nil otherwise
func (payload *Pong) Validate() error {
return nil
}
func (payload *Pong) Type() interceptor.SubType {
return "pong"
}
-1
View File
@@ -1 +0,0 @@
package pong
-141
View File
@@ -1,141 +0,0 @@
package pong
import (
"context"
"sync"
"time"
"github.com/harshabose/skyline_sonata/serve/pkg/interceptor"
)
// pong represents a single pong response record.
// It stores information about a received pong message including its unique ID,
// the calculated round-trip time, and when it was received. This data is used
// for connection health analysis and statistics.
type pong struct {
messageid string // Unique identifier matching the corresponding ping
timestamp time.Time // When this pong was received
}
// ping represents a single ping request record.
// It stores information about an already sent ping message including its unique ID
// and when it was sent. This allows for matching with corresponding pongs
// and calculating accurate round-trip times.
type ping struct {
messageid string // Unique identifier for matching with corresponding pong
timestamp time.Time // When this ping was sent
}
// recent tracks the most recently processed ping and pong messages.
// This provides quick access to the latest connection health data
// without needing to search through the full history arrays.
type recent struct {
ping *ping // Most recent ping sent
pong *pong // Most recent pong received
}
// state maintains connection-specific ping/pong tracking information.
// Each websocket connection has its own state instance that records
// ping/pong history, calculates statistics, and provides methods for
// analyzing connection health and performance.
type state struct {
peerid string
writer interceptor.Writer
reader interceptor.Reader
pongs []*pong // Historical record of pongs received
pings []*ping // Historical record of pings sent
max uint16 // Maximum number of ping/pong records to keep
recvd int // Total count of pongs received
sent int // Total count of pings sent
recent recent // Most recent ping and pong
mux sync.RWMutex // Mutex for thread-safe access to state
ctx context.Context
cancel context.CancelFunc
}
// recordPong processes a received pong message and updates the state accordingly.
// It calculates the round-trip time based on the original ping timestamp,
// records the pong in the history (maintaining the maximum history size),
// updates the recent pong reference, and increments the received count.
//
// Parameters:
// - payload: The pong message received from the client
func (state *state) recordPong(payload *Pong) {
state.mux.Lock()
defer state.mux.Unlock()
pong := &pong{
messageid: payload.MessageID,
timestamp: time.Now(),
}
state.recent.pong = pong
if uint16(len(state.pongs)) >= state.max {
if len(state.pongs) > 0 {
state.pongs = state.pongs[1:]
}
}
state.pongs = append(state.pongs, pong)
state.recvd++
}
// recordPing processes an already sent ping message and updates the state accordingly.
// It records the ping in the history (maintaining the maximum history size),
// updates the recent ping reference, and increments the already sent count.
// This is typically called when the interceptor sends a ping, but could also
// track pings from the client in bidirectional ping/pong implementations.
//
// Parameters:
// - payload: The ping message sent to the client
func (state *state) recordPing(payload *Ping) {
state.mux.Lock()
defer state.mux.Unlock()
ping := &ping{
messageid: payload.MessageID,
timestamp: payload.Timestamp,
}
state.recent.ping = ping
if uint16(len(state.pings)) >= state.max {
if len(state.pings) > 0 {
state.pings = state.pings[1:]
}
}
state.pings = append(state.pings, ping)
state.sent++
}
// GetSuccessRate returns the percentage of pings that received corresponding pongs.
// This metric helps assess connection reliability by measuring how many ping
// requests are successfully acknowledged by the client.
//
// Returns:
// - The success rate as a percentage (0-100), or zero if no pings have been sent
func (state *state) GetSuccessRate() float64 {
state.mux.RLock()
defer state.mux.RUnlock()
if state.sent == 0 {
return 0
}
return 100.0 * (1.0 - float64(state.sent-state.recvd)/float64(state.sent))
}
// cleanup releases all resources held by this state.
// It clears all ping and pong records, resets counters, and removes references
// to recent ping/pong objects. This is typically called when a connection
// is closed or when the interceptor is shutting down.
func (state *state) cleanup() {
state.mux.Lock()
defer state.mux.Unlock()
state.pings = nil
state.pongs = nil
state.max = 0
state.sent = 0
state.recvd = 0
state.recent.pong = nil
state.recent.ping = nil
}
+4 -4
View File
@@ -92,7 +92,7 @@ func (i *Interceptor) Close() error {
// ================================================================================================================== //
// ================================================================================================================== //
func (payload *CreateRoom) Process(header interceptor.Header, _interceptor interceptor.Interceptor, connection interceptor.Connection) error {
func (payload *CreateRoom) Process(header message.Header, _interceptor interceptor.Interceptor, connection interceptor.Connection) error {
if err := payload.Validate(); err != nil {
return err
}
@@ -128,7 +128,7 @@ func (payload *CreateRoom) Process(header interceptor.Header, _interceptor inter
return nil
}
func (payload *JoinRoom) Process(header interceptor.Header, _interceptor interceptor.Interceptor, connection interceptor.Connection) error {
func (payload *JoinRoom) Process(header message.Header, _interceptor interceptor.Interceptor, connection interceptor.Connection) error {
if err := payload.Validate(); err != nil {
return err
}
@@ -157,7 +157,7 @@ func (payload *JoinRoom) Process(header interceptor.Header, _interceptor interce
return r.add(connection, state)
}
func (payload *LeaveRoom) Process(header interceptor.Header, _interceptor interceptor.Interceptor, connection interceptor.Connection) error {
func (payload *LeaveRoom) Process(header message.Header, _interceptor interceptor.Interceptor, connection interceptor.Connection) error {
if err := payload.Validate(); err != nil {
return err
}
@@ -186,7 +186,7 @@ func (payload *LeaveRoom) Process(header interceptor.Header, _interceptor interc
return r.remove(connection)
}
func (payload *ChatSource) Process(header interceptor.Header, _interceptor interceptor.Interceptor, connection interceptor.Connection) error {
func (payload *ChatSource) Process(header message.Header, _interceptor interceptor.Interceptor, connection interceptor.Connection) error {
if err := payload.Validate(); err != nil {
return err
}
+6 -5
View File
@@ -6,6 +6,7 @@ import (
"time"
"github.com/harshabose/skyline_sonata/serve/pkg/interceptor"
"github.com/harshabose/skyline_sonata/serve/pkg/message"
)
var (
@@ -185,7 +186,7 @@ func (payload *ChatDest) Validate() error {
return nil
}
func (payload *ChatDest) Process(_ interceptor.Header, _ interceptor.Interceptor, _ interceptor.Connection) error {
func (payload *ChatDest) Process(_ message.Header, _ interceptor.Interceptor, _ interceptor.Connection) error {
return nil
}
@@ -215,7 +216,7 @@ func (payload *ClientJoined) Validate() error {
return nil
}
func (payload *ClientJoined) Process(_ interceptor.Header, _ interceptor.Interceptor, _ interceptor.Connection) error {
func (payload *ClientJoined) Process(_ message.Header, _ interceptor.Interceptor, _ interceptor.Connection) error {
return nil
}
@@ -245,7 +246,7 @@ func (payload *ClientLeft) Validate() error {
return nil
}
func (payload *ClientLeft) Process(_ interceptor.Header, _ interceptor.Interceptor, _ interceptor.Connection) error {
func (payload *ClientLeft) Process(_ message.Header, _ interceptor.Interceptor, _ interceptor.Connection) error {
return nil
}
@@ -270,7 +271,7 @@ func (payload *Success) Validate() error {
return nil
}
func (payload *Success) Process(_ interceptor.Header, _ interceptor.Interceptor, _ interceptor.Connection) error {
func (payload *Success) Process(_ message.Header, _ interceptor.Interceptor, _ interceptor.Connection) error {
return nil
}
@@ -307,7 +308,7 @@ func (payload *Error) Validate() error {
return nil
}
func (payload *Error) Process(_ interceptor.Header, _ interceptor.Interceptor, _ interceptor.Connection) error {
func (payload *Error) Process(_ message.Header, _ interceptor.Interceptor, _ interceptor.Connection) error {
return nil
}
+7
View File
@@ -0,0 +1,7 @@
package message
import "errors"
var (
ErrorNotValid = errors.New("not valid")
)
+60 -1
View File
@@ -1,12 +1,22 @@
package message
import "encoding/json"
import (
"encoding/json"
"github.com/harshabose/skyline_sonata/serve/pkg/interceptor"
)
type Protocol string
var NoneProtocol Protocol = "none"
type Message interface {
Marshal() ([]byte, error)
Unmarshal([]byte) error
Protocol() Protocol
Message() *BaseMessage
Validate() error
Process(interceptor.Interceptor, interceptor.Connection) error
}
type Header struct {
@@ -15,6 +25,14 @@ type Header struct {
Protocol Protocol `json:"protocol"`
}
func (header *Header) Validate() error {
if header.SenderID == "" || header.ReceiverID == "" || header.Protocol == "" {
return ErrorNotValid
}
return nil
}
type BaseMessage struct {
Header
Payload json.RawMessage `json:"payload,omitempty"`
@@ -27,3 +45,44 @@ func (msg *BaseMessage) Marshal() ([]byte, error) {
func (msg *BaseMessage) Unmarshal(data []byte) error {
return json.Unmarshal(data, msg)
}
func (msg *BaseMessage) Protocol() Protocol {
return NoneProtocol
}
func (msg *BaseMessage) Message() *BaseMessage {
return msg
}
func (msg *BaseMessage) Validate() error {
return msg.Header.Validate()
}
func (msg *BaseMessage) Process(interceptor.Interceptor, interceptor.Connection) error {
return nil
}
func CreateMessage(senderID, receiverID string, payload Message) (*BaseMessage, error) {
var (
data json.RawMessage = nil
protocol = NoneProtocol
err error = nil
)
if payload != nil {
data, err = payload.Marshal()
if err != nil {
return nil, err
}
protocol = payload.Protocol()
}
return &BaseMessage{
Header: Header{
SenderID: senderID,
ReceiverID: receiverID,
Protocol: protocol,
},
Payload: data,
}, nil
}
+7 -3
View File
@@ -17,17 +17,21 @@ func NewMultiError() *MultiError {
}
// Add appends an error to the collection if it's not nil
func (multiErr *MultiError) Add(err error) {
func (multiErr *MultiError) Add(err error) *MultiError {
if err != nil {
multiErr.errors = append(multiErr.errors, err)
}
return multiErr
}
// AddAll appends multiple errors to the collection, ignoring nil errors
func (multiErr *MultiError) AddAll(errs ...error) {
func (multiErr *MultiError) AddAll(errs ...error) *MultiError {
for _, err := range errs {
multiErr.Add(err)
_ = multiErr.Add(err)
}
return multiErr
}
// Len returns the number of errors in the collection