diff --git a/.gitignore b/.gitignore index e69de29..ad29947 100644 --- a/.gitignore +++ b/.gitignore @@ -0,0 +1,2 @@ +/.idea +/third_party/ diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..f8b02a2 --- /dev/null +++ b/go.mod @@ -0,0 +1,8 @@ +module github.com/harshabose/skyline_sonata/serve + +go 1.24.1 + +require ( + github.com/coder/websocket v1.8.12 + golang.org/x/time v0.11.0 +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..7173253 --- /dev/null +++ b/go.sum @@ -0,0 +1,4 @@ +github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= +github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= +golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0= +golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= diff --git a/pkg/interceptor/auth/factory.go b/pkg/interceptor/auth/factory.go new file mode 100644 index 0000000..8832b06 --- /dev/null +++ b/pkg/interceptor/auth/factory.go @@ -0,0 +1 @@ +package auth diff --git a/pkg/interceptor/auth/interceptor.go b/pkg/interceptor/auth/interceptor.go new file mode 100644 index 0000000..4d7ded0 --- /dev/null +++ b/pkg/interceptor/auth/interceptor.go @@ -0,0 +1,19 @@ +package auth + +import ( + "github.com/coder/websocket" + + "github.com/harshabose/skyline_sonata/serve/pkg/interceptor" +) + +type Interceptor struct { + interceptor.NoOpInterceptor +} + +func (auth *Interceptor) BindConnection(connection *websocket.Conn) { + return +} + +func (auth *Interceptor) Close() error { + return nil +} diff --git a/pkg/interceptor/chain.go b/pkg/interceptor/chain.go index 1a045e9..4fe11da 100644 --- a/pkg/interceptor/chain.go +++ b/pkg/interceptor/chain.go @@ -1,5 +1,7 @@ package interceptor +import "github.com/coder/websocket" + type Chain struct { interceptors []Interceptor } @@ -8,28 +10,47 @@ func CreateChain(interceptors []Interceptor) *Chain { return &Chain{interceptors: interceptors} } -func (chain *Chain) BindIncoming(reader IncomingReader) IncomingReader { +func (chain *Chain) BindSocketConnection(connection *websocket.Conn) error { for _, interceptor := range chain.interceptors { - interceptor.BindIncoming(reader) + if err := interceptor.BindSocketConnection(connection); err != nil { + return err + } } - - return reader + return nil } -func (chain *Chain) BindOutgoing(writer OutgoingWriter) OutgoingWriter { +func (chain *Chain) BindSocketWriter(writer Writer) Writer { for _, interceptor := range chain.interceptors { - interceptor.BindOutgoing(writer) + writer = interceptor.BindSocketWriter(writer) } return writer } -func (chain *Chain) BindConnection(connection Connection) Connection { +func (chain *Chain) BindSocketReader(reader Reader) Reader { for _, interceptor := range chain.interceptors { - interceptor.BindConnection(connection) + reader = interceptor.BindSocketReader(reader) } - return connection + return reader +} + +func (chain *Chain) UnBindSocketConnection(connection *websocket.Conn) { + for _, interceptor := range chain.interceptors { + interceptor.UnBindSocketConnection(connection) + } +} + +func (chain *Chain) UnBindSocketWriter(writer Writer) { + for _, interceptor := range chain.interceptors { + interceptor.UnBindSocketWriter(writer) + } +} + +func (chain *Chain) UnBindSocketReader(reader Reader) { + for _, interceptor := range chain.interceptors { + interceptor.UnBindSocketReader(reader) + } } func (chain *Chain) Close() error { diff --git a/pkg/interceptor/encrypt/factory.go b/pkg/interceptor/encrypt/factory.go new file mode 100644 index 0000000..7e5445a --- /dev/null +++ b/pkg/interceptor/encrypt/factory.go @@ -0,0 +1 @@ +package encrypt diff --git a/pkg/interceptor/encrypt/interceptor.go b/pkg/interceptor/encrypt/interceptor.go new file mode 100644 index 0000000..1d9dbe4 --- /dev/null +++ b/pkg/interceptor/encrypt/interceptor.go @@ -0,0 +1,154 @@ +package encrypt + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/binary" + "fmt" + "io" + "sync" + + "github.com/coder/websocket" + + "github.com/harshabose/skyline_sonata/serve/pkg/interceptor" +) + +type collection struct { + encryptor cipher.AEAD // For encrypting outgoing messages + decryptor cipher.AEAD // For decrypting incoming messages +} + +type Interceptor struct { + interceptor.NoOpInterceptor + collection map[*websocket.Conn]*collection + mux sync.Mutex +} + +func (encrypt *Interceptor) BindSocketConnection(connection *websocket.Conn) error { + encrypt.mux.Lock() + defer encrypt.mux.Unlock() + + // Generate a key for this connection + key := make([]byte, 32) // AES-256 + if _, err := io.ReadFull(rand.Reader, key); err != nil { + return fmt.Errorf("failed to generate encryption key: %w", err) + } + + // Create AES cipher + block, err := aes.NewCipher(key) + if err != nil { + return fmt.Errorf("failed to create AES cipher: %w", err) + } + + // Create GCM mode encryptor + gcm, err := cipher.NewGCM(block) + if err != nil { + return fmt.Errorf("failed to create GCM mode: %w", err) + } + _ = gcm + + // TODO: Exchange keys with the peer using a key exchange protocol like Diffie-Hellman + // TODO: Store different keys for encryption and decryption + + encrypt.collection[connection] = &collection{ + encryptor: gcm, + decryptor: gcm, + } + return nil +} + +func (encrypt *Interceptor) BindSocketWriter(writer interceptor.Writer) interceptor.Writer { + return interceptor.WriterFunc(func(conn *websocket.Conn, messageType websocket.MessageType, data []byte) error { + encrypt.mux.Lock() + connState, exists := encrypt.collection[conn] + encrypt.mux.Unlock() + + if !exists || connState.encryptor == nil { + // No encryption configured for this connection yet + // Pass through unencrypted + return writer.Write(conn, messageType, data) + } + + // Generate a nonce for this message + nonce := make([]byte, 12) // GCM typically uses a 12-byte nonce + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return fmt.Errorf("failed to generate nonce: %w", err) + } + + // Encrypt the data + encryptor := connState.encryptor + encryptedData := encryptor.Seal(nil, nonce, data, nil) + + // Format the encrypted message: + // [2-byte nonce length][nonce][encrypted data] + finalData := make([]byte, 2+len(nonce)+len(encryptedData)) + binary.BigEndian.PutUint16(finalData[:2], uint16(len(nonce))) + copy(finalData[2:], nonce) + copy(finalData[2+len(nonce):], encryptedData) + + // Send the encrypted message + return writer.Write(conn, messageType, finalData) + }) +} + +func (encrypt *Interceptor) BindSocketReader(reader interceptor.Reader) interceptor.Reader { + return interceptor.ReaderFunc(func(conn *websocket.Conn) (websocket.MessageType, []byte, error) { + // Read the encrypted message + messageType, encryptedData, err := reader.Read(conn) + if err != nil { + return messageType, encryptedData, err + } + + encrypt.mux.Lock() + collection, exists := encrypt.collection[conn] + encrypt.mux.Unlock() + + if !exists || collection.decryptor == nil || len(encryptedData) < 2 { + // No decryption configured or data too short to be encrypted + // Pass through as-is + return messageType, encryptedData, nil + } + + // Extract nonce length + nonceLen := binary.BigEndian.Uint16(encryptedData[:2]) + + // Ensure we have enough data for the nonce and at least some ciphertext + if len(encryptedData) < int(2+nonceLen) { + return messageType, encryptedData, fmt.Errorf("encrypted data too short") + } + + // Extract nonce and ciphertext + nonce := encryptedData[2 : 2+nonceLen] + ciphertext := encryptedData[2+nonceLen:] + + // Decrypt the data + decryptor := collection.decryptor + plaintext, err := decryptor.Open(nil, nonce, ciphertext, nil) + if err != nil { + return messageType, encryptedData, fmt.Errorf("decryption failed: %w", err) + } + + return messageType, plaintext, nil + }) +} + +func (encrypt *Interceptor) UnBindSocketConnection(connection *websocket.Conn) { + encrypt.mux.Lock() + defer encrypt.mux.Unlock() + + delete(encrypt.collection, connection) +} + +func (encrypt *Interceptor) UnBindSocketWriter(_ interceptor.Writer) {} + +func (encrypt *Interceptor) UnBindSocketReader(_ interceptor.Reader) {} + +func (encrypt *Interceptor) Close() error { + encrypt.mux.Lock() + defer encrypt.mux.Unlock() + + encrypt.collection = make(map[*websocket.Conn]*collection) + + return nil +} diff --git a/pkg/interceptor/errors.go b/pkg/interceptor/errors.go index fb1a506..c372ea9 100644 --- a/pkg/interceptor/errors.go +++ b/pkg/interceptor/errors.go @@ -20,30 +20,30 @@ func flattenErrs(errs []error) error { type multiError []error -func (me multiError) Error() string { - var errstrings []string +func (errs multiError) Error() string { + var errStrings []string - for _, err := range me { + for _, err := range errs { if err != nil { - errstrings = append(errstrings, err.Error()) + errStrings = append(errStrings, err.Error()) } } - if len(errstrings) == 0 { + if len(errStrings) == 0 { return "multiError must contain multiple error but is empty" } - return strings.Join(errstrings, "\n") + return strings.Join(errStrings, "\n") } -func (me multiError) Is(err error) bool { - for _, e := range me { +func (errs multiError) Is(err error) bool { + for _, e := range errs { if errors.Is(e, err) { return true } - var me2 multiError - if errors.As(e, &me2) { - if me2.Is(err) { + var errs2 multiError + if errors.As(e, &errs2) { + if errs2.Is(err) { return true } } diff --git a/pkg/interceptor/interceptor.go b/pkg/interceptor/interceptor.go index 3972aa2..8c107da 100644 --- a/pkg/interceptor/interceptor.go +++ b/pkg/interceptor/interceptor.go @@ -1,27 +1,37 @@ package interceptor import ( + "context" "io" - "github.com/harshabose/skyline_sonata/serve/pkg/message" + "github.com/coder/websocket" ) +// Registry maintains a collection of interceptor factories that can be used to +// build a chain of interceptors for a given context and ID. type Registry struct { factories []Factory } +// Register adds a new interceptor factory to the registry. +// Factories are stored in the order they're registered, which determines +// the order of interceptors in the resulting chain. func (registry *Registry) Register(factory Factory) { registry.factories = append(registry.factories, factory) } -func (registry *Registry) Build(id string) (Interceptor, error) { +// Build creates a chain of interceptors by invoking each registered factory. +// If no factories are registered, returns a no-op interceptor. +// The context and ID are passed to each factory to allow for customized +// interceptor creation based on request context or client identity. +func (registry *Registry) Build(ctx context.Context, id string) (Interceptor, error) { if len(registry.factories) == 0 { - return &NoInterceptor{}, nil + return &NoOpInterceptor{}, nil } interceptors := make([]Interceptor, 0) for _, factory := range registry.factories { - interceptor, err := factory.NewInterceptor(id) + interceptor, err := factory.NewInterceptor(ctx, id) if err != nil { return nil, err } @@ -34,31 +44,119 @@ func (registry *Registry) Build(id string) (Interceptor, error) { // Factory provides an interface for constructing interceptors type Factory interface { - NewInterceptor(id string) (Interceptor, error) + NewInterceptor(context.Context, string) (Interceptor, error) } -// Interceptor are transformers which bind to incoming, outgoing and connection of a client of the websocket. This can -// be used to add functionalities to the websocket connection. +// Interceptor defines a transformer that can modify the behavior of websocket connections. +// Interceptors can bind to the connection itself, its writers (for outgoing messages), +// and its readers (for incoming messages). This pattern enables adding functionalities +// like logging, encryption, compression, rate limiting, or analytics to websocket +// connections without modifying the core websocket handling code. type Interceptor interface { - // BindIncoming binds to incoming messages to a client - BindIncoming(IncomingReader) IncomingReader + // BindSocketConnection is called when a new websocket connection is established. + // It gives the interceptor an opportunity to set up any connection-specific + // state or perform initialization tasks for the given connection. + // Returns an error if the binding process fails, which typically would + // result in the connection being rejected. + BindSocketConnection(*websocket.Conn) error - // BindOutgoing binds to outgoing messages from a client - BindOutgoing(OutgoingWriter) OutgoingWriter + // BindSocketWriter wraps a writer that handles messages going out to clients. + // The interceptor receives the original writer and returns a modified writer + // that adds the interceptor's functionality. For example, an encryption interceptor + // would return a writer that encrypts messages before passing them to the original writer. + // The returned writer will be used for all future write operations on the connection. + BindSocketWriter(Writer) Writer - // BindConnection binds to the websocket connection itself - BindConnection(Connection) Connection + // BindSocketReader wraps a reader that handles messages coming in from clients. + // The interceptor receives the original reader and returns a modified reader + // that adds the interceptor's functionality. For example, a logging interceptor + // would return a reader that logs messages after receiving them from the original reader. + // The returned reader will be used for all future read operations on the connection. + BindSocketReader(Reader) Reader + // UnBindSocketConnection is called when a websocket connection is closed or removed. + // It cleans up any connection-specific resources and state maintained by the interceptor + // for the given connection, removing it from the collection map to prevent memory leaks. + UnBindSocketConnection(*websocket.Conn) + + // UnBindSocketWriter is called when a writer is being removed or when the + // connection is closing. This gives the interceptor an opportunity to clean up + // any resources or state associated with the writer. The interceptor should + // release any references to the writer to prevent memory leaks. + UnBindSocketWriter(Writer) + + // UnBindSocketReader is called when a reader is being removed or when the + // connection is closing. This gives the interceptor an opportunity to clean up + // any resources or state associated with the reader. The interceptor should + // release any references to the reader to prevent memory leaks. + UnBindSocketReader(Reader) + + // Closer interface implementation for resource cleanup. + // Close is called when the interceptor itself is being shut down. + // It should clean up any global resources held by the interceptor. io.Closer } -type IncomingReader interface { - Read([]byte) (int, error) +// Writer is an interface for writing messages to a websocket connection +type Writer interface { + // Write sends a message to the connection + // Takes the connection, message type, and data to write + // Returns any error encountered during writing + Write(conn *websocket.Conn, messageType websocket.MessageType, data []byte) error } -type OutgoingWriter interface { - Write(message message.BaseMessage) (int, error) +// Reader is an interface for reading messages from a websocket connection +type Reader interface { + // Read reads a message from the connection + // Returns the message type, message data, and any error + Read(conn *websocket.Conn) (messageType websocket.MessageType, data []byte, err error) } -type Connection interface { +// ReaderFunc is a function type that implements the Reader interface +type ReaderFunc func(conn *websocket.Conn) (messageType websocket.MessageType, data []byte, err error) + +// Read implements the Reader interface for ReaderFunc +func (f ReaderFunc) Read(conn *websocket.Conn) (messageType websocket.MessageType, data []byte, err error) { + return f(conn) +} + +// WriterFunc is a function type that implements the Writer interface +type WriterFunc func(conn *websocket.Conn, messageType websocket.MessageType, data []byte) error + +// Write implements the Writer interface for WriterFunc +func (f WriterFunc) Write(conn *websocket.Conn, messageType websocket.MessageType, data []byte) error { + return f(conn, messageType, data) +} + +// NoOpInterceptor implements the Interceptor interface with no-op methods. +// It's used as a fallback when no interceptors are configured or as a base +// struct that other interceptors can embed to avoid implementing all methods. +type NoOpInterceptor struct{} + +// BindSocketConnection is a no-op implementation that accepts any connection. +func (interceptor *NoOpInterceptor) BindSocketConnection(_ *websocket.Conn) error { + return nil +} + +// BindSocketWriter returns the original writer without modification. +func (interceptor *NoOpInterceptor) BindSocketWriter(writer Writer) Writer { + return writer +} + +// BindSocketReader returns the original reader without modification. +func (interceptor *NoOpInterceptor) BindSocketReader(reader Reader) Reader { + return reader +} + +func (interceptor *NoOpInterceptor) UnBindSocketConnection(_ *websocket.Conn) {} + +// UnBindSocketWriter performs no cleanup operations. +func (interceptor *NoOpInterceptor) UnBindSocketWriter(_ Writer) {} + +// UnBindSocketReader performs no cleanup operations. +func (interceptor *NoOpInterceptor) UnBindSocketReader(_ Reader) {} + +// Close performs no cleanup operations. +func (interceptor *NoOpInterceptor) Close() error { + return nil } diff --git a/pkg/interceptor/log/factory.go b/pkg/interceptor/log/factory.go new file mode 100644 index 0000000..7330d54 --- /dev/null +++ b/pkg/interceptor/log/factory.go @@ -0,0 +1 @@ +package log diff --git a/pkg/interceptor/log/interceptor.go b/pkg/interceptor/log/interceptor.go new file mode 100644 index 0000000..7330d54 --- /dev/null +++ b/pkg/interceptor/log/interceptor.go @@ -0,0 +1 @@ +package log diff --git a/pkg/interceptor/no_interceptor.go b/pkg/interceptor/no_interceptor.go deleted file mode 100644 index 82e5f63..0000000 --- a/pkg/interceptor/no_interceptor.go +++ /dev/null @@ -1,19 +0,0 @@ -package interceptor - -type NoInterceptor struct{} - -func (interceptor *NoInterceptor) BindIncoming(reader IncomingReader) IncomingReader { - return reader -} - -func (interceptor *NoInterceptor) BindOutgoing(writer OutgoingWriter) OutgoingWriter { - return writer -} - -func (interceptor *NoInterceptor) BindConnection(connection Connection) Connection { - return connection -} - -func (interceptor *NoInterceptor) Close() error { - return nil -} diff --git a/pkg/interceptor/ping/factory.go b/pkg/interceptor/ping/factory.go new file mode 100644 index 0000000..13142a6 --- /dev/null +++ b/pkg/interceptor/ping/factory.go @@ -0,0 +1,50 @@ +package ping + +import ( + "context" + "time" +) + +type Option = func(*Interceptor) error + +type InterceptorFactory struct { + opts []Option +} + +func WithInterval(interval time.Duration) Option { + return func(interceptor *Interceptor) error { + interceptor.interval = interval + return nil + } +} + +func WithStoreMax(max uint16) Option { + return func(interceptor *Interceptor) error { + interceptor.statsFactory.opts = append(interceptor.statsFactory.opts, withMax(max)) + return nil + } +} + +func CreateInterceptorFactory(options ...Option) *InterceptorFactory { + return &InterceptorFactory{ + opts: options, + } +} + +func (factory *InterceptorFactory) NewInterceptor(ctx context.Context, id string) (*Interceptor, error) { + pingInterceptor := &Interceptor{ + close: make(chan struct{}), + statsFactory: statsFactory{}, + ctx: ctx, + } + + for _, option := range factory.opts { + if err := option(pingInterceptor); err != nil { + return nil, err + } + } + + go pingInterceptor.loop() + + return pingInterceptor, nil +} diff --git a/pkg/interceptor/ping/interceptor.go b/pkg/interceptor/ping/interceptor.go new file mode 100644 index 0000000..9c712f4 --- /dev/null +++ b/pkg/interceptor/ping/interceptor.go @@ -0,0 +1,176 @@ +package ping + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/coder/websocket" + + "github.com/harshabose/skyline_sonata/serve/pkg/interceptor" + "github.com/harshabose/skyline_sonata/serve/pkg/message" +) + +// collection combines connection-specific statistics and a writer reference, +// storing all per-connection state needed by the ping interceptor. +type collection struct { + *pings + interceptor.Writer +} + +// Interceptor implements a ping mechanism to maintain websocket connections. +// It periodically sends ping messages and tracks their responses to monitor +// connection health. +type Interceptor struct { + interceptor.NoOpInterceptor + interval time.Duration + statsFactory statsFactory + collection map[*websocket.Conn]*collection + mux sync.RWMutex + close chan struct{} + ctx context.Context +} + +// BindSocketConnection initializes tracking for a new websocket connection by +// creating pings for it and storing it in the collection map. The writer will +// be set later when BindSocketWriter is called for this connection. +func (ping *Interceptor) BindSocketConnection(connection *websocket.Conn) error { + ping.mux.Lock() + defer ping.mux.Unlock() + + stats, err := ping.statsFactory.createStats() + if err != nil { + return err + } + + ping.collection[connection] = &collection{stats, nil} + return nil +} + +// BindSocketWriter wraps a writer to store the writer for later. +func (ping *Interceptor) BindSocketWriter(writer interceptor.Writer) interceptor.Writer { + return interceptor.WriterFunc(func(conn *websocket.Conn, messageType websocket.MessageType, data []byte) error { + // Store the writer for this connection + // Storing the writer allows other interceptors to perform their interceptions as well + ping.mux.Lock() + defer ping.mux.Unlock() + if _, exists := ping.collection[conn]; exists { + ping.collection[conn].Writer = writer + } + + // Pass through to original writer + // No manipulation of writer, just storing + return writer.Write(conn, messageType, data) + }) +} + +// BindSocketReader wraps a reader to handle pong responses. +func (ping *Interceptor) BindSocketReader(reader interceptor.Reader) interceptor.Reader { + return interceptor.ReaderFunc(func(conn *websocket.Conn) (websocket.MessageType, []byte, error) { + messageType, data, err := reader.Read(conn) + if err != nil { + return messageType, data, err + } + + msg := &message.Pong{} + if err := msg.Unmarshal(data); err == nil { + // Message is Pong message + ping.collection[conn].pings.recordPong(msg) + } + + return messageType, data, nil + }) +} + +// UnBindSocketConnection removes a connection from the interceptor's tracking. +// This is called when a connection is closed, ensuring that resources +// associated with the connection are freed and preventing memory leaks. +func (ping *Interceptor) UnBindSocketConnection(connection *websocket.Conn) { + ping.mux.Lock() + defer ping.mux.Unlock() + + delete(ping.collection, connection) +} + +// UnBindSocketWriter performs cleanup when a writer is being removed. +// Since the writer references are stored by connection, writer don't need +// special cleanup for individual writers. +func (ping *Interceptor) UnBindSocketWriter(_ interceptor.Writer) { + // If left, unimplemented, NoOpInterceptor's default implementation will be used + // But, for reference, this method is implemented +} + +// UnBindSocketReader performs cleanup when a reader is being removed. +// Since the Interceptor don't maintain reader-specific state, no specific cleanup is needed. +func (ping *Interceptor) UnBindSocketReader(_ interceptor.Reader) { + // If left, unimplemented, NoOpInterceptor's default implementation will be used + // But, for reference, this method is implemented +} + +// Close shuts down the ping interceptor and cleans up all resources. +// It signals the background ping loop to stop, waits for confirmation +// that it has stopped, and cleans up any remaining connection state. +// This method is safe to call multiple times. +func (ping *Interceptor) Close() error { + select { + case ping.close <- struct{}{}: + // sent signal successfully + default: + // already closing/closed + } + + ping.mux.Lock() + defer ping.mux.Unlock() + ping.collection = make(map[*websocket.Conn]*collection) + + return nil +} + +func (ping *Interceptor) loop() { + ticker := time.NewTicker(ping.interval) + defer ticker.Stop() + + for { + select { + case <-ping.ctx.Done(): + return + case <-ping.close: + return + case <-ticker.C: + ping.mux.RLock() + + // Send ping messages to all connections + for conn, collection := range ping.collection { + if collection.Writer == nil { + fmt.Println("writer not bound yet; skipping...") + continue + } + + msg := message.CreatePingMessage(time.Now()) + data, err := msg.Marshal() + if err != nil { + fmt.Println("error while marshaling ping message; skipping...") + continue + } + + // Use the stored writer instead of sending through websocket.Conn.Write(...) + if err := collection.Writer.Write(conn, websocket.MessageText, data); err != nil { + fmt.Println("error while sending ping message; skipping...") + continue + } + + // Record successful ping + collection.pings.recordSentPing(msg) + } + + ping.mux.RUnlock() + } + } +} + +func (ping *Interceptor) sendPing() { + ping.mux.RLock() + defer ping.mux.RUnlock() + +} diff --git a/pkg/interceptor/ping/ping.go b/pkg/interceptor/ping/ping.go new file mode 100644 index 0000000..0b1c859 --- /dev/null +++ b/pkg/interceptor/ping/ping.go @@ -0,0 +1,140 @@ +package ping + +import ( + "time" + + "github.com/harshabose/skyline_sonata/serve/pkg/message" +) + +type ping struct { + rtt time.Duration // Round-trip time for ping-pong + timestamp time.Time // When this ping was recorded +} + +type pings struct { + pings []ping // Historical pings, limited by max capacity + max uint16 // Maximum number of pings to keep + recvd int // Total pongs received + count int // Total pings sent + recent ping // Most recent ping +} + +// recordPong updates pings based on a received pong message +func (s *pings) recordPong(msg *message.Pong) { + rtt := msg.Timestamp.Sub(msg.PingTimestamp) + + newStat := ping{ + rtt: rtt, + timestamp: time.Now(), + } + + s.recent = newStat + + if uint16(len(s.pings)) >= s.max { + if len(s.pings) > 0 { + s.pings = s.pings[1:] + } + } + s.pings = append(s.pings, newStat) + s.recordReceivedPong(msg) +} + +// recordSentPing increments the count of pings sent +func (s *pings) recordSentPing(_ *message.Ping) { + s.count++ +} + +func (s *pings) recordReceivedPong(_ *message.Pong) { + s.recvd++ +} + +// GetRecentRTT returns the most recent round-trip time +func (s *pings) GetRecentRTT() time.Duration { + return s.recent.rtt +} + +// GetAverageRTT calculates the average round-trip time +func (s *pings) GetAverageRTT() time.Duration { + if len(s.pings) == 0 { + return 0 + } + + var total time.Duration + for _, stat := range s.pings { + total += stat.rtt + } + + return total / time.Duration(len(s.pings)) +} + +// GetMaxRTT returns the maximum round-trip time observed +func (s *pings) GetMaxRTT() time.Duration { + if len(s.pings) == 0 { + return 0 + } + + var maxRTT time.Duration + for _, stat := range s.pings { + if stat.rtt > maxRTT { + maxRTT = stat.rtt + } + } + + return maxRTT +} + +// GetMinRTT returns the minimum round-trip time observed +func (s *pings) GetMinRTT() time.Duration { + if len(s.pings) == 0 { + return 0 + } + + minRTT := s.pings[0].rtt + for _, stat := range s.pings { + if stat.rtt < minRTT { + minRTT = stat.rtt + } + } + + return minRTT +} + +// GetSuccessRate returns the percentage of successful pings +func (s *pings) GetSuccessRate() float64 { + if s.count == 0 { + return 0 + } + + return 100.0 * (1.0 - float64(s.count-s.recvd)/float64(s.count)) +} + +type statsOption = func(*pings) error + +func withMax(max uint16) statsOption { + return func(s *pings) error { + s.max = max + return nil + } +} + +type statsFactory struct { + opts []statsOption +} + +func (factory *statsFactory) createStats() (*pings, error) { + stats := &pings{ + pings: make([]ping, 0), + max: ^uint16(0), + count: 0, + recvd: 0, + recent: ping{}, + } + + for _, option := range factory.opts { + if err := option(stats); err != nil { + return nil, err + } + } + + return stats, nil +} diff --git a/pkg/interceptor/room/factory.go b/pkg/interceptor/room/factory.go new file mode 100644 index 0000000..0191265 --- /dev/null +++ b/pkg/interceptor/room/factory.go @@ -0,0 +1 @@ +package room diff --git a/pkg/interceptor/room/interceptor.go b/pkg/interceptor/room/interceptor.go new file mode 100644 index 0000000..0191265 --- /dev/null +++ b/pkg/interceptor/room/interceptor.go @@ -0,0 +1 @@ +package room diff --git a/pkg/interceptor/stats/factory.go b/pkg/interceptor/stats/factory.go new file mode 100644 index 0000000..43b4fd5 --- /dev/null +++ b/pkg/interceptor/stats/factory.go @@ -0,0 +1 @@ +package stats diff --git a/pkg/interceptor/stats/interceptor.go b/pkg/interceptor/stats/interceptor.go new file mode 100644 index 0000000..43b4fd5 --- /dev/null +++ b/pkg/interceptor/stats/interceptor.go @@ -0,0 +1 @@ +package stats diff --git a/pkg/message/message.go b/pkg/message/message.go index 6a9d915..ebec213 100644 --- a/pkg/message/message.go +++ b/pkg/message/message.go @@ -2,13 +2,10 @@ package message type BaseMessage interface { Marshal() ([]byte, error) + Unmarshal([]byte) error } type Header struct { SourceID string `json:"source_id"` DestinationID string `json:"destination_id"` } - -type Message struct { - Header -} diff --git a/pkg/message/messages.go b/pkg/message/messages.go new file mode 100644 index 0000000..f788f58 --- /dev/null +++ b/pkg/message/messages.go @@ -0,0 +1,54 @@ +package message + +import ( + "encoding/json" + "time" +) + +type Ping struct { + Header + Timestamp time.Time `json:"timestamp"` +} + +func (ping *Ping) Marshal() ([]byte, error) { + return json.Marshal(ping) +} + +func (ping *Ping) Unmarshal(data []byte) error { + return json.Unmarshal(data, ping) +} + +func CreatePingMessage(timestamp time.Time) *Ping { + return &Ping{ + Header: Header{ + SourceID: "server", + DestinationID: "unknown", + }, + Timestamp: timestamp, + } +} + +type Pong struct { + Header + PingTimestamp time.Time `json:"ping_timestamp"` + Timestamp time.Time `json:"timestamp"` +} + +func (pong *Pong) Marshal() ([]byte, error) { + return json.Marshal(pong) +} + +func (pong *Pong) Unmarshal(data []byte) error { + return json.Unmarshal(data, pong) +} + +func CreatePongMessage(pingTimestamp time.Time, timestamp time.Time) *Pong { + return &Pong{ + Header: Header{ + SourceID: "unknown", + DestinationID: "server", + }, + PingTimestamp: pingTimestamp, + Timestamp: timestamp, + } +} diff --git a/pkg/socket/client.go b/pkg/socket/client.go index 69513dc..da15339 100644 --- a/pkg/socket/client.go +++ b/pkg/socket/client.go @@ -1 +1,16 @@ package socket + +import "github.com/coder/websocket" + +type clients = map[string]*client + +type client struct { + id string + connection *websocket.Conn +} + +func createClient(connection *websocket.Conn) *client { + return &client{ + connection: connection, + } +} diff --git a/pkg/socket/settings.go b/pkg/socket/settings.go index e69e75a..8f6333a 100644 --- a/pkg/socket/settings.go +++ b/pkg/socket/settings.go @@ -1,8 +1,83 @@ package socket -type Settings struct { +import ( + "crypto/tls" + "net/http" + "time" + + "golang.org/x/time/rate" +) + +type apiSettings struct { } -func RegisterDefaultSettings(settings *Settings) error { +func registerDefaultAPISettings(settings *apiSettings) error { + return nil +} + +type settings struct { + // Server settings + ReadTimeout time.Duration + WriteTimeout time.Duration + IdleTimeout time.Duration + ReadHeaderTimeout time.Duration + MaxHeaderBytes int + ShutdownTimeout time.Duration + + // TLS configuration + TLSConfig *tls.Config + TLSCertFile string + TLSKeyFile string + + // Connection settings + MaxConnections int + ConnectionTimeout time.Duration + + // WebSocket specific + PingInterval time.Duration + PongWait time.Duration + WriteWait time.Duration + MessageSizeLimit int64 + + // Router settings + BasePath string + EnableCORS bool + CORSAllowOrigins []string + CORSAllowMethods []string + CORSAllowHeaders []string + + // Middleware + EnableLogging bool + EnableCompression bool + RateLimiter *rate.Limiter +} + +func (s *settings) apply(socket *Socket) { + socket.server.ReadTimeout = s.ReadTimeout + socket.server.WriteTimeout = s.WriteTimeout + socket.server.IdleTimeout = s.IdleTimeout + socket.server.ReadHeaderTimeout = s.ReadHeaderTimeout + socket.server.MaxHeaderBytes = s.MaxHeaderBytes + + socket.server.TLSConfig = s.TLSConfig + + if s.EnableCORS { + // s.applyCORS() + } + + if s.EnableLogging { + + } + + if s.EnableCompression { + + } +} + +func (s *settings) applyCORS(handler *http.HandlerFunc) { + +} + +func registerDefaultSettings(settings *settings) error { return nil } diff --git a/pkg/socket/socket.go b/pkg/socket/socket.go index b912702..fd0b9f5 100644 --- a/pkg/socket/socket.go +++ b/pkg/socket/socket.go @@ -1,25 +1,24 @@ package socket import ( + "context" + "errors" + "fmt" + "net/http" + "sync" + "github.com/coder/websocket" "github.com/harshabose/skyline_sonata/serve/pkg/interceptor" ) type API struct { - settings *Settings + settings *apiSettings interceptorRegistry *interceptor.Registry } type APIOption = func(*API) error -func WithSocketSettings(settings *Settings) APIOption { - return func(api *API) error { - api.settings = settings - return nil - } -} - func WithInterceptorRegistry(registry *interceptor.Registry) APIOption { return func(api *API) error { api.interceptorRegistry = registry @@ -27,55 +26,103 @@ func WithInterceptorRegistry(registry *interceptor.Registry) APIOption { } } -func CreateSocketFactory(options ...APIOption) (*API, error) { +func CreateAPI(options ...APIOption) (*API, error) { api := &API{ - settings: nil, + settings: &apiSettings{}, interceptorRegistry: nil, } + if err := registerDefaultAPISettings(api.settings); err != nil { + return nil, err + } + for _, option := range options { if err := option(api); err != nil { return nil, err } } - if api.settings == nil { - api.settings = &Settings{} - if err := RegisterDefaultSettings(api.settings); err != nil { - return nil, err - } - } - return api, nil } -func (api *API) CreateWebSocket(id string, options ...Option) (*Socket, error) { +func (api *API) CreateWebSocket(ctx context.Context, id string, options ...Option) (*Socket, error) { socket := &Socket{ - id: id, + id: id, + settings: &settings{}, + socketAcceptOptions: &websocket.AcceptOptions{}, + ctx: ctx, } - interceptors, err := api.interceptorRegistry.Build(id) + interceptors, err := api.interceptorRegistry.Build(ctx, id) if err != nil { return nil, err } socket.interceptor = interceptors + if err := registerDefaultSettings(socket.settings); err != nil { + return nil, err + } + for _, option := range options { if err := option(socket); err != nil { return nil, err } } - return socket, nil + return socket.setup(), nil } type Socket struct { - id string - connections map[string]*websocket.Conn - interceptor interceptor.Interceptor + id string + settings *settings + server *http.Server + router *http.ServeMux + handlerFunc *http.HandlerFunc + socketAcceptOptions *websocket.AcceptOptions + interceptor interceptor.Interceptor + mux sync.RWMutex + ctx context.Context } -func (socket *Socket) Serve() { +func (socket *Socket) setup() *Socket { + socket.router = http.NewServeMux() + socket.server = &http.Server{} + // socket.handlerFunc = socket.wssHandler + + socket.settings.apply(socket) + + return socket +} + +func (socket *Socket) serve() error { + defer socket.close() + + for { + select { + case <-socket.ctx.Done(): + return nil + default: + if err := socket.server.ListenAndServeTLS(socket.settings.TLSCertFile, socket.settings.TLSKeyFile); err != nil { + fmt.Println(errors.New("error while serving HTTP server")) + fmt.Println("trying again...") + } + } + } +} + +func (socket *Socket) baseHandler(w http.ResponseWriter, r *http.Request) { + connection, err := websocket.Accept(w, r, socket.socketAcceptOptions) + if err != nil { + fmt.Println(errors.New("error while accepting socket connection")) + } + + if err := socket.interceptor.BindSocketConnection(connection); err != nil { + fmt.Println("error while handling client:", err.Error()) + return + } +} + +func (socket *Socket) close() { }