mirror of
https://github.com/harshabose/serve.git
synced 2026-04-22 23:07:27 +08:00
added ping and encrypt interceptors
This commit is contained in:
@@ -0,0 +1,2 @@
|
||||
/.idea
|
||||
/third_party/
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
module github.com/harshabose/skyline_sonata/serve
|
||||
|
||||
go 1.24.1
|
||||
|
||||
require (
|
||||
github.com/coder/websocket v1.8.12
|
||||
golang.org/x/time v0.11.0
|
||||
)
|
||||
@@ -0,0 +1,4 @@
|
||||
github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo=
|
||||
github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
|
||||
golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0=
|
||||
golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||
@@ -0,0 +1 @@
|
||||
package auth
|
||||
@@ -0,0 +1,19 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"github.com/coder/websocket"
|
||||
|
||||
"github.com/harshabose/skyline_sonata/serve/pkg/interceptor"
|
||||
)
|
||||
|
||||
type Interceptor struct {
|
||||
interceptor.NoOpInterceptor
|
||||
}
|
||||
|
||||
func (auth *Interceptor) BindConnection(connection *websocket.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
func (auth *Interceptor) Close() error {
|
||||
return nil
|
||||
}
|
||||
@@ -1,5 +1,7 @@
|
||||
package interceptor
|
||||
|
||||
import "github.com/coder/websocket"
|
||||
|
||||
type Chain struct {
|
||||
interceptors []Interceptor
|
||||
}
|
||||
@@ -8,28 +10,47 @@ func CreateChain(interceptors []Interceptor) *Chain {
|
||||
return &Chain{interceptors: interceptors}
|
||||
}
|
||||
|
||||
func (chain *Chain) BindIncoming(reader IncomingReader) IncomingReader {
|
||||
func (chain *Chain) BindSocketConnection(connection *websocket.Conn) error {
|
||||
for _, interceptor := range chain.interceptors {
|
||||
interceptor.BindIncoming(reader)
|
||||
if err := interceptor.BindSocketConnection(connection); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return reader
|
||||
return nil
|
||||
}
|
||||
|
||||
func (chain *Chain) BindOutgoing(writer OutgoingWriter) OutgoingWriter {
|
||||
func (chain *Chain) BindSocketWriter(writer Writer) Writer {
|
||||
for _, interceptor := range chain.interceptors {
|
||||
interceptor.BindOutgoing(writer)
|
||||
writer = interceptor.BindSocketWriter(writer)
|
||||
}
|
||||
|
||||
return writer
|
||||
}
|
||||
|
||||
func (chain *Chain) BindConnection(connection Connection) Connection {
|
||||
func (chain *Chain) BindSocketReader(reader Reader) Reader {
|
||||
for _, interceptor := range chain.interceptors {
|
||||
interceptor.BindConnection(connection)
|
||||
reader = interceptor.BindSocketReader(reader)
|
||||
}
|
||||
|
||||
return connection
|
||||
return reader
|
||||
}
|
||||
|
||||
func (chain *Chain) UnBindSocketConnection(connection *websocket.Conn) {
|
||||
for _, interceptor := range chain.interceptors {
|
||||
interceptor.UnBindSocketConnection(connection)
|
||||
}
|
||||
}
|
||||
|
||||
func (chain *Chain) UnBindSocketWriter(writer Writer) {
|
||||
for _, interceptor := range chain.interceptors {
|
||||
interceptor.UnBindSocketWriter(writer)
|
||||
}
|
||||
}
|
||||
|
||||
func (chain *Chain) UnBindSocketReader(reader Reader) {
|
||||
for _, interceptor := range chain.interceptors {
|
||||
interceptor.UnBindSocketReader(reader)
|
||||
}
|
||||
}
|
||||
|
||||
func (chain *Chain) Close() error {
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
package encrypt
|
||||
@@ -0,0 +1,154 @@
|
||||
package encrypt
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
|
||||
"github.com/harshabose/skyline_sonata/serve/pkg/interceptor"
|
||||
)
|
||||
|
||||
type collection struct {
|
||||
encryptor cipher.AEAD // For encrypting outgoing messages
|
||||
decryptor cipher.AEAD // For decrypting incoming messages
|
||||
}
|
||||
|
||||
type Interceptor struct {
|
||||
interceptor.NoOpInterceptor
|
||||
collection map[*websocket.Conn]*collection
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
func (encrypt *Interceptor) BindSocketConnection(connection *websocket.Conn) error {
|
||||
encrypt.mux.Lock()
|
||||
defer encrypt.mux.Unlock()
|
||||
|
||||
// Generate a key for this connection
|
||||
key := make([]byte, 32) // AES-256
|
||||
if _, err := io.ReadFull(rand.Reader, key); err != nil {
|
||||
return fmt.Errorf("failed to generate encryption key: %w", err)
|
||||
}
|
||||
|
||||
// Create AES cipher
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create AES cipher: %w", err)
|
||||
}
|
||||
|
||||
// Create GCM mode encryptor
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create GCM mode: %w", err)
|
||||
}
|
||||
_ = gcm
|
||||
|
||||
// TODO: Exchange keys with the peer using a key exchange protocol like Diffie-Hellman
|
||||
// TODO: Store different keys for encryption and decryption
|
||||
|
||||
encrypt.collection[connection] = &collection{
|
||||
encryptor: gcm,
|
||||
decryptor: gcm,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (encrypt *Interceptor) BindSocketWriter(writer interceptor.Writer) interceptor.Writer {
|
||||
return interceptor.WriterFunc(func(conn *websocket.Conn, messageType websocket.MessageType, data []byte) error {
|
||||
encrypt.mux.Lock()
|
||||
connState, exists := encrypt.collection[conn]
|
||||
encrypt.mux.Unlock()
|
||||
|
||||
if !exists || connState.encryptor == nil {
|
||||
// No encryption configured for this connection yet
|
||||
// Pass through unencrypted
|
||||
return writer.Write(conn, messageType, data)
|
||||
}
|
||||
|
||||
// Generate a nonce for this message
|
||||
nonce := make([]byte, 12) // GCM typically uses a 12-byte nonce
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt the data
|
||||
encryptor := connState.encryptor
|
||||
encryptedData := encryptor.Seal(nil, nonce, data, nil)
|
||||
|
||||
// Format the encrypted message:
|
||||
// [2-byte nonce length][nonce][encrypted data]
|
||||
finalData := make([]byte, 2+len(nonce)+len(encryptedData))
|
||||
binary.BigEndian.PutUint16(finalData[:2], uint16(len(nonce)))
|
||||
copy(finalData[2:], nonce)
|
||||
copy(finalData[2+len(nonce):], encryptedData)
|
||||
|
||||
// Send the encrypted message
|
||||
return writer.Write(conn, messageType, finalData)
|
||||
})
|
||||
}
|
||||
|
||||
func (encrypt *Interceptor) BindSocketReader(reader interceptor.Reader) interceptor.Reader {
|
||||
return interceptor.ReaderFunc(func(conn *websocket.Conn) (websocket.MessageType, []byte, error) {
|
||||
// Read the encrypted message
|
||||
messageType, encryptedData, err := reader.Read(conn)
|
||||
if err != nil {
|
||||
return messageType, encryptedData, err
|
||||
}
|
||||
|
||||
encrypt.mux.Lock()
|
||||
collection, exists := encrypt.collection[conn]
|
||||
encrypt.mux.Unlock()
|
||||
|
||||
if !exists || collection.decryptor == nil || len(encryptedData) < 2 {
|
||||
// No decryption configured or data too short to be encrypted
|
||||
// Pass through as-is
|
||||
return messageType, encryptedData, nil
|
||||
}
|
||||
|
||||
// Extract nonce length
|
||||
nonceLen := binary.BigEndian.Uint16(encryptedData[:2])
|
||||
|
||||
// Ensure we have enough data for the nonce and at least some ciphertext
|
||||
if len(encryptedData) < int(2+nonceLen) {
|
||||
return messageType, encryptedData, fmt.Errorf("encrypted data too short")
|
||||
}
|
||||
|
||||
// Extract nonce and ciphertext
|
||||
nonce := encryptedData[2 : 2+nonceLen]
|
||||
ciphertext := encryptedData[2+nonceLen:]
|
||||
|
||||
// Decrypt the data
|
||||
decryptor := collection.decryptor
|
||||
plaintext, err := decryptor.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return messageType, encryptedData, fmt.Errorf("decryption failed: %w", err)
|
||||
}
|
||||
|
||||
return messageType, plaintext, nil
|
||||
})
|
||||
}
|
||||
|
||||
func (encrypt *Interceptor) UnBindSocketConnection(connection *websocket.Conn) {
|
||||
encrypt.mux.Lock()
|
||||
defer encrypt.mux.Unlock()
|
||||
|
||||
delete(encrypt.collection, connection)
|
||||
}
|
||||
|
||||
func (encrypt *Interceptor) UnBindSocketWriter(_ interceptor.Writer) {}
|
||||
|
||||
func (encrypt *Interceptor) UnBindSocketReader(_ interceptor.Reader) {}
|
||||
|
||||
func (encrypt *Interceptor) Close() error {
|
||||
encrypt.mux.Lock()
|
||||
defer encrypt.mux.Unlock()
|
||||
|
||||
encrypt.collection = make(map[*websocket.Conn]*collection)
|
||||
|
||||
return nil
|
||||
}
|
||||
+11
-11
@@ -20,30 +20,30 @@ func flattenErrs(errs []error) error {
|
||||
|
||||
type multiError []error
|
||||
|
||||
func (me multiError) Error() string {
|
||||
var errstrings []string
|
||||
func (errs multiError) Error() string {
|
||||
var errStrings []string
|
||||
|
||||
for _, err := range me {
|
||||
for _, err := range errs {
|
||||
if err != nil {
|
||||
errstrings = append(errstrings, err.Error())
|
||||
errStrings = append(errStrings, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if len(errstrings) == 0 {
|
||||
if len(errStrings) == 0 {
|
||||
return "multiError must contain multiple error but is empty"
|
||||
}
|
||||
|
||||
return strings.Join(errstrings, "\n")
|
||||
return strings.Join(errStrings, "\n")
|
||||
}
|
||||
|
||||
func (me multiError) Is(err error) bool {
|
||||
for _, e := range me {
|
||||
func (errs multiError) Is(err error) bool {
|
||||
for _, e := range errs {
|
||||
if errors.Is(e, err) {
|
||||
return true
|
||||
}
|
||||
var me2 multiError
|
||||
if errors.As(e, &me2) {
|
||||
if me2.Is(err) {
|
||||
var errs2 multiError
|
||||
if errors.As(e, &errs2) {
|
||||
if errs2.Is(err) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
+116
-18
@@ -1,27 +1,37 @@
|
||||
package interceptor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/harshabose/skyline_sonata/serve/pkg/message"
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
// Registry maintains a collection of interceptor factories that can be used to
|
||||
// build a chain of interceptors for a given context and ID.
|
||||
type Registry struct {
|
||||
factories []Factory
|
||||
}
|
||||
|
||||
// Register adds a new interceptor factory to the registry.
|
||||
// Factories are stored in the order they're registered, which determines
|
||||
// the order of interceptors in the resulting chain.
|
||||
func (registry *Registry) Register(factory Factory) {
|
||||
registry.factories = append(registry.factories, factory)
|
||||
}
|
||||
|
||||
func (registry *Registry) Build(id string) (Interceptor, error) {
|
||||
// Build creates a chain of interceptors by invoking each registered factory.
|
||||
// If no factories are registered, returns a no-op interceptor.
|
||||
// The context and ID are passed to each factory to allow for customized
|
||||
// interceptor creation based on request context or client identity.
|
||||
func (registry *Registry) Build(ctx context.Context, id string) (Interceptor, error) {
|
||||
if len(registry.factories) == 0 {
|
||||
return &NoInterceptor{}, nil
|
||||
return &NoOpInterceptor{}, nil
|
||||
}
|
||||
|
||||
interceptors := make([]Interceptor, 0)
|
||||
for _, factory := range registry.factories {
|
||||
interceptor, err := factory.NewInterceptor(id)
|
||||
interceptor, err := factory.NewInterceptor(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -34,31 +44,119 @@ func (registry *Registry) Build(id string) (Interceptor, error) {
|
||||
|
||||
// Factory provides an interface for constructing interceptors
|
||||
type Factory interface {
|
||||
NewInterceptor(id string) (Interceptor, error)
|
||||
NewInterceptor(context.Context, string) (Interceptor, error)
|
||||
}
|
||||
|
||||
// Interceptor are transformers which bind to incoming, outgoing and connection of a client of the websocket. This can
|
||||
// be used to add functionalities to the websocket connection.
|
||||
// Interceptor defines a transformer that can modify the behavior of websocket connections.
|
||||
// Interceptors can bind to the connection itself, its writers (for outgoing messages),
|
||||
// and its readers (for incoming messages). This pattern enables adding functionalities
|
||||
// like logging, encryption, compression, rate limiting, or analytics to websocket
|
||||
// connections without modifying the core websocket handling code.
|
||||
type Interceptor interface {
|
||||
// BindIncoming binds to incoming messages to a client
|
||||
BindIncoming(IncomingReader) IncomingReader
|
||||
// BindSocketConnection is called when a new websocket connection is established.
|
||||
// It gives the interceptor an opportunity to set up any connection-specific
|
||||
// state or perform initialization tasks for the given connection.
|
||||
// Returns an error if the binding process fails, which typically would
|
||||
// result in the connection being rejected.
|
||||
BindSocketConnection(*websocket.Conn) error
|
||||
|
||||
// BindOutgoing binds to outgoing messages from a client
|
||||
BindOutgoing(OutgoingWriter) OutgoingWriter
|
||||
// BindSocketWriter wraps a writer that handles messages going out to clients.
|
||||
// The interceptor receives the original writer and returns a modified writer
|
||||
// that adds the interceptor's functionality. For example, an encryption interceptor
|
||||
// would return a writer that encrypts messages before passing them to the original writer.
|
||||
// The returned writer will be used for all future write operations on the connection.
|
||||
BindSocketWriter(Writer) Writer
|
||||
|
||||
// BindConnection binds to the websocket connection itself
|
||||
BindConnection(Connection) Connection
|
||||
// BindSocketReader wraps a reader that handles messages coming in from clients.
|
||||
// The interceptor receives the original reader and returns a modified reader
|
||||
// that adds the interceptor's functionality. For example, a logging interceptor
|
||||
// would return a reader that logs messages after receiving them from the original reader.
|
||||
// The returned reader will be used for all future read operations on the connection.
|
||||
BindSocketReader(Reader) Reader
|
||||
|
||||
// UnBindSocketConnection is called when a websocket connection is closed or removed.
|
||||
// It cleans up any connection-specific resources and state maintained by the interceptor
|
||||
// for the given connection, removing it from the collection map to prevent memory leaks.
|
||||
UnBindSocketConnection(*websocket.Conn)
|
||||
|
||||
// UnBindSocketWriter is called when a writer is being removed or when the
|
||||
// connection is closing. This gives the interceptor an opportunity to clean up
|
||||
// any resources or state associated with the writer. The interceptor should
|
||||
// release any references to the writer to prevent memory leaks.
|
||||
UnBindSocketWriter(Writer)
|
||||
|
||||
// UnBindSocketReader is called when a reader is being removed or when the
|
||||
// connection is closing. This gives the interceptor an opportunity to clean up
|
||||
// any resources or state associated with the reader. The interceptor should
|
||||
// release any references to the reader to prevent memory leaks.
|
||||
UnBindSocketReader(Reader)
|
||||
|
||||
// Closer interface implementation for resource cleanup.
|
||||
// Close is called when the interceptor itself is being shut down.
|
||||
// It should clean up any global resources held by the interceptor.
|
||||
io.Closer
|
||||
}
|
||||
|
||||
type IncomingReader interface {
|
||||
Read([]byte) (int, error)
|
||||
// Writer is an interface for writing messages to a websocket connection
|
||||
type Writer interface {
|
||||
// Write sends a message to the connection
|
||||
// Takes the connection, message type, and data to write
|
||||
// Returns any error encountered during writing
|
||||
Write(conn *websocket.Conn, messageType websocket.MessageType, data []byte) error
|
||||
}
|
||||
|
||||
type OutgoingWriter interface {
|
||||
Write(message message.BaseMessage) (int, error)
|
||||
// Reader is an interface for reading messages from a websocket connection
|
||||
type Reader interface {
|
||||
// Read reads a message from the connection
|
||||
// Returns the message type, message data, and any error
|
||||
Read(conn *websocket.Conn) (messageType websocket.MessageType, data []byte, err error)
|
||||
}
|
||||
|
||||
type Connection interface {
|
||||
// ReaderFunc is a function type that implements the Reader interface
|
||||
type ReaderFunc func(conn *websocket.Conn) (messageType websocket.MessageType, data []byte, err error)
|
||||
|
||||
// Read implements the Reader interface for ReaderFunc
|
||||
func (f ReaderFunc) Read(conn *websocket.Conn) (messageType websocket.MessageType, data []byte, err error) {
|
||||
return f(conn)
|
||||
}
|
||||
|
||||
// WriterFunc is a function type that implements the Writer interface
|
||||
type WriterFunc func(conn *websocket.Conn, messageType websocket.MessageType, data []byte) error
|
||||
|
||||
// Write implements the Writer interface for WriterFunc
|
||||
func (f WriterFunc) Write(conn *websocket.Conn, messageType websocket.MessageType, data []byte) error {
|
||||
return f(conn, messageType, data)
|
||||
}
|
||||
|
||||
// NoOpInterceptor implements the Interceptor interface with no-op methods.
|
||||
// It's used as a fallback when no interceptors are configured or as a base
|
||||
// struct that other interceptors can embed to avoid implementing all methods.
|
||||
type NoOpInterceptor struct{}
|
||||
|
||||
// BindSocketConnection is a no-op implementation that accepts any connection.
|
||||
func (interceptor *NoOpInterceptor) BindSocketConnection(_ *websocket.Conn) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// BindSocketWriter returns the original writer without modification.
|
||||
func (interceptor *NoOpInterceptor) BindSocketWriter(writer Writer) Writer {
|
||||
return writer
|
||||
}
|
||||
|
||||
// BindSocketReader returns the original reader without modification.
|
||||
func (interceptor *NoOpInterceptor) BindSocketReader(reader Reader) Reader {
|
||||
return reader
|
||||
}
|
||||
|
||||
func (interceptor *NoOpInterceptor) UnBindSocketConnection(_ *websocket.Conn) {}
|
||||
|
||||
// UnBindSocketWriter performs no cleanup operations.
|
||||
func (interceptor *NoOpInterceptor) UnBindSocketWriter(_ Writer) {}
|
||||
|
||||
// UnBindSocketReader performs no cleanup operations.
|
||||
func (interceptor *NoOpInterceptor) UnBindSocketReader(_ Reader) {}
|
||||
|
||||
// Close performs no cleanup operations.
|
||||
func (interceptor *NoOpInterceptor) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
package log
|
||||
@@ -0,0 +1 @@
|
||||
package log
|
||||
@@ -1,19 +0,0 @@
|
||||
package interceptor
|
||||
|
||||
type NoInterceptor struct{}
|
||||
|
||||
func (interceptor *NoInterceptor) BindIncoming(reader IncomingReader) IncomingReader {
|
||||
return reader
|
||||
}
|
||||
|
||||
func (interceptor *NoInterceptor) BindOutgoing(writer OutgoingWriter) OutgoingWriter {
|
||||
return writer
|
||||
}
|
||||
|
||||
func (interceptor *NoInterceptor) BindConnection(connection Connection) Connection {
|
||||
return connection
|
||||
}
|
||||
|
||||
func (interceptor *NoInterceptor) Close() error {
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
package ping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Option = func(*Interceptor) error
|
||||
|
||||
type InterceptorFactory struct {
|
||||
opts []Option
|
||||
}
|
||||
|
||||
func WithInterval(interval time.Duration) Option {
|
||||
return func(interceptor *Interceptor) error {
|
||||
interceptor.interval = interval
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithStoreMax(max uint16) Option {
|
||||
return func(interceptor *Interceptor) error {
|
||||
interceptor.statsFactory.opts = append(interceptor.statsFactory.opts, withMax(max))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func CreateInterceptorFactory(options ...Option) *InterceptorFactory {
|
||||
return &InterceptorFactory{
|
||||
opts: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (factory *InterceptorFactory) NewInterceptor(ctx context.Context, id string) (*Interceptor, error) {
|
||||
pingInterceptor := &Interceptor{
|
||||
close: make(chan struct{}),
|
||||
statsFactory: statsFactory{},
|
||||
ctx: ctx,
|
||||
}
|
||||
|
||||
for _, option := range factory.opts {
|
||||
if err := option(pingInterceptor); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
go pingInterceptor.loop()
|
||||
|
||||
return pingInterceptor, nil
|
||||
}
|
||||
@@ -0,0 +1,176 @@
|
||||
package ping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
|
||||
"github.com/harshabose/skyline_sonata/serve/pkg/interceptor"
|
||||
"github.com/harshabose/skyline_sonata/serve/pkg/message"
|
||||
)
|
||||
|
||||
// collection combines connection-specific statistics and a writer reference,
|
||||
// storing all per-connection state needed by the ping interceptor.
|
||||
type collection struct {
|
||||
*pings
|
||||
interceptor.Writer
|
||||
}
|
||||
|
||||
// Interceptor implements a ping mechanism to maintain websocket connections.
|
||||
// It periodically sends ping messages and tracks their responses to monitor
|
||||
// connection health.
|
||||
type Interceptor struct {
|
||||
interceptor.NoOpInterceptor
|
||||
interval time.Duration
|
||||
statsFactory statsFactory
|
||||
collection map[*websocket.Conn]*collection
|
||||
mux sync.RWMutex
|
||||
close chan struct{}
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// BindSocketConnection initializes tracking for a new websocket connection by
|
||||
// creating pings for it and storing it in the collection map. The writer will
|
||||
// be set later when BindSocketWriter is called for this connection.
|
||||
func (ping *Interceptor) BindSocketConnection(connection *websocket.Conn) error {
|
||||
ping.mux.Lock()
|
||||
defer ping.mux.Unlock()
|
||||
|
||||
stats, err := ping.statsFactory.createStats()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ping.collection[connection] = &collection{stats, nil}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BindSocketWriter wraps a writer to store the writer for later.
|
||||
func (ping *Interceptor) BindSocketWriter(writer interceptor.Writer) interceptor.Writer {
|
||||
return interceptor.WriterFunc(func(conn *websocket.Conn, messageType websocket.MessageType, data []byte) error {
|
||||
// Store the writer for this connection
|
||||
// Storing the writer allows other interceptors to perform their interceptions as well
|
||||
ping.mux.Lock()
|
||||
defer ping.mux.Unlock()
|
||||
if _, exists := ping.collection[conn]; exists {
|
||||
ping.collection[conn].Writer = writer
|
||||
}
|
||||
|
||||
// Pass through to original writer
|
||||
// No manipulation of writer, just storing
|
||||
return writer.Write(conn, messageType, data)
|
||||
})
|
||||
}
|
||||
|
||||
// BindSocketReader wraps a reader to handle pong responses.
|
||||
func (ping *Interceptor) BindSocketReader(reader interceptor.Reader) interceptor.Reader {
|
||||
return interceptor.ReaderFunc(func(conn *websocket.Conn) (websocket.MessageType, []byte, error) {
|
||||
messageType, data, err := reader.Read(conn)
|
||||
if err != nil {
|
||||
return messageType, data, err
|
||||
}
|
||||
|
||||
msg := &message.Pong{}
|
||||
if err := msg.Unmarshal(data); err == nil {
|
||||
// Message is Pong message
|
||||
ping.collection[conn].pings.recordPong(msg)
|
||||
}
|
||||
|
||||
return messageType, data, nil
|
||||
})
|
||||
}
|
||||
|
||||
// UnBindSocketConnection removes a connection from the interceptor's tracking.
|
||||
// This is called when a connection is closed, ensuring that resources
|
||||
// associated with the connection are freed and preventing memory leaks.
|
||||
func (ping *Interceptor) UnBindSocketConnection(connection *websocket.Conn) {
|
||||
ping.mux.Lock()
|
||||
defer ping.mux.Unlock()
|
||||
|
||||
delete(ping.collection, connection)
|
||||
}
|
||||
|
||||
// UnBindSocketWriter performs cleanup when a writer is being removed.
|
||||
// Since the writer references are stored by connection, writer don't need
|
||||
// special cleanup for individual writers.
|
||||
func (ping *Interceptor) UnBindSocketWriter(_ interceptor.Writer) {
|
||||
// If left, unimplemented, NoOpInterceptor's default implementation will be used
|
||||
// But, for reference, this method is implemented
|
||||
}
|
||||
|
||||
// UnBindSocketReader performs cleanup when a reader is being removed.
|
||||
// Since the Interceptor don't maintain reader-specific state, no specific cleanup is needed.
|
||||
func (ping *Interceptor) UnBindSocketReader(_ interceptor.Reader) {
|
||||
// If left, unimplemented, NoOpInterceptor's default implementation will be used
|
||||
// But, for reference, this method is implemented
|
||||
}
|
||||
|
||||
// Close shuts down the ping interceptor and cleans up all resources.
|
||||
// It signals the background ping loop to stop, waits for confirmation
|
||||
// that it has stopped, and cleans up any remaining connection state.
|
||||
// This method is safe to call multiple times.
|
||||
func (ping *Interceptor) Close() error {
|
||||
select {
|
||||
case ping.close <- struct{}{}:
|
||||
// sent signal successfully
|
||||
default:
|
||||
// already closing/closed
|
||||
}
|
||||
|
||||
ping.mux.Lock()
|
||||
defer ping.mux.Unlock()
|
||||
ping.collection = make(map[*websocket.Conn]*collection)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ping *Interceptor) loop() {
|
||||
ticker := time.NewTicker(ping.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ping.ctx.Done():
|
||||
return
|
||||
case <-ping.close:
|
||||
return
|
||||
case <-ticker.C:
|
||||
ping.mux.RLock()
|
||||
|
||||
// Send ping messages to all connections
|
||||
for conn, collection := range ping.collection {
|
||||
if collection.Writer == nil {
|
||||
fmt.Println("writer not bound yet; skipping...")
|
||||
continue
|
||||
}
|
||||
|
||||
msg := message.CreatePingMessage(time.Now())
|
||||
data, err := msg.Marshal()
|
||||
if err != nil {
|
||||
fmt.Println("error while marshaling ping message; skipping...")
|
||||
continue
|
||||
}
|
||||
|
||||
// Use the stored writer instead of sending through websocket.Conn.Write(...)
|
||||
if err := collection.Writer.Write(conn, websocket.MessageText, data); err != nil {
|
||||
fmt.Println("error while sending ping message; skipping...")
|
||||
continue
|
||||
}
|
||||
|
||||
// Record successful ping
|
||||
collection.pings.recordSentPing(msg)
|
||||
}
|
||||
|
||||
ping.mux.RUnlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ping *Interceptor) sendPing() {
|
||||
ping.mux.RLock()
|
||||
defer ping.mux.RUnlock()
|
||||
|
||||
}
|
||||
@@ -0,0 +1,140 @@
|
||||
package ping
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/harshabose/skyline_sonata/serve/pkg/message"
|
||||
)
|
||||
|
||||
type ping struct {
|
||||
rtt time.Duration // Round-trip time for ping-pong
|
||||
timestamp time.Time // When this ping was recorded
|
||||
}
|
||||
|
||||
type pings struct {
|
||||
pings []ping // Historical pings, limited by max capacity
|
||||
max uint16 // Maximum number of pings to keep
|
||||
recvd int // Total pongs received
|
||||
count int // Total pings sent
|
||||
recent ping // Most recent ping
|
||||
}
|
||||
|
||||
// recordPong updates pings based on a received pong message
|
||||
func (s *pings) recordPong(msg *message.Pong) {
|
||||
rtt := msg.Timestamp.Sub(msg.PingTimestamp)
|
||||
|
||||
newStat := ping{
|
||||
rtt: rtt,
|
||||
timestamp: time.Now(),
|
||||
}
|
||||
|
||||
s.recent = newStat
|
||||
|
||||
if uint16(len(s.pings)) >= s.max {
|
||||
if len(s.pings) > 0 {
|
||||
s.pings = s.pings[1:]
|
||||
}
|
||||
}
|
||||
s.pings = append(s.pings, newStat)
|
||||
s.recordReceivedPong(msg)
|
||||
}
|
||||
|
||||
// recordSentPing increments the count of pings sent
|
||||
func (s *pings) recordSentPing(_ *message.Ping) {
|
||||
s.count++
|
||||
}
|
||||
|
||||
func (s *pings) recordReceivedPong(_ *message.Pong) {
|
||||
s.recvd++
|
||||
}
|
||||
|
||||
// GetRecentRTT returns the most recent round-trip time
|
||||
func (s *pings) GetRecentRTT() time.Duration {
|
||||
return s.recent.rtt
|
||||
}
|
||||
|
||||
// GetAverageRTT calculates the average round-trip time
|
||||
func (s *pings) GetAverageRTT() time.Duration {
|
||||
if len(s.pings) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
var total time.Duration
|
||||
for _, stat := range s.pings {
|
||||
total += stat.rtt
|
||||
}
|
||||
|
||||
return total / time.Duration(len(s.pings))
|
||||
}
|
||||
|
||||
// GetMaxRTT returns the maximum round-trip time observed
|
||||
func (s *pings) GetMaxRTT() time.Duration {
|
||||
if len(s.pings) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
var maxRTT time.Duration
|
||||
for _, stat := range s.pings {
|
||||
if stat.rtt > maxRTT {
|
||||
maxRTT = stat.rtt
|
||||
}
|
||||
}
|
||||
|
||||
return maxRTT
|
||||
}
|
||||
|
||||
// GetMinRTT returns the minimum round-trip time observed
|
||||
func (s *pings) GetMinRTT() time.Duration {
|
||||
if len(s.pings) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
minRTT := s.pings[0].rtt
|
||||
for _, stat := range s.pings {
|
||||
if stat.rtt < minRTT {
|
||||
minRTT = stat.rtt
|
||||
}
|
||||
}
|
||||
|
||||
return minRTT
|
||||
}
|
||||
|
||||
// GetSuccessRate returns the percentage of successful pings
|
||||
func (s *pings) GetSuccessRate() float64 {
|
||||
if s.count == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
return 100.0 * (1.0 - float64(s.count-s.recvd)/float64(s.count))
|
||||
}
|
||||
|
||||
type statsOption = func(*pings) error
|
||||
|
||||
func withMax(max uint16) statsOption {
|
||||
return func(s *pings) error {
|
||||
s.max = max
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
type statsFactory struct {
|
||||
opts []statsOption
|
||||
}
|
||||
|
||||
func (factory *statsFactory) createStats() (*pings, error) {
|
||||
stats := &pings{
|
||||
pings: make([]ping, 0),
|
||||
max: ^uint16(0),
|
||||
count: 0,
|
||||
recvd: 0,
|
||||
recent: ping{},
|
||||
}
|
||||
|
||||
for _, option := range factory.opts {
|
||||
if err := option(stats); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
package room
|
||||
@@ -0,0 +1 @@
|
||||
package room
|
||||
@@ -0,0 +1 @@
|
||||
package stats
|
||||
@@ -0,0 +1 @@
|
||||
package stats
|
||||
@@ -2,13 +2,10 @@ package message
|
||||
|
||||
type BaseMessage interface {
|
||||
Marshal() ([]byte, error)
|
||||
Unmarshal([]byte) error
|
||||
}
|
||||
|
||||
type Header struct {
|
||||
SourceID string `json:"source_id"`
|
||||
DestinationID string `json:"destination_id"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Header
|
||||
}
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Ping struct {
|
||||
Header
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
|
||||
func (ping *Ping) Marshal() ([]byte, error) {
|
||||
return json.Marshal(ping)
|
||||
}
|
||||
|
||||
func (ping *Ping) Unmarshal(data []byte) error {
|
||||
return json.Unmarshal(data, ping)
|
||||
}
|
||||
|
||||
func CreatePingMessage(timestamp time.Time) *Ping {
|
||||
return &Ping{
|
||||
Header: Header{
|
||||
SourceID: "server",
|
||||
DestinationID: "unknown",
|
||||
},
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
type Pong struct {
|
||||
Header
|
||||
PingTimestamp time.Time `json:"ping_timestamp"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
|
||||
func (pong *Pong) Marshal() ([]byte, error) {
|
||||
return json.Marshal(pong)
|
||||
}
|
||||
|
||||
func (pong *Pong) Unmarshal(data []byte) error {
|
||||
return json.Unmarshal(data, pong)
|
||||
}
|
||||
|
||||
func CreatePongMessage(pingTimestamp time.Time, timestamp time.Time) *Pong {
|
||||
return &Pong{
|
||||
Header: Header{
|
||||
SourceID: "unknown",
|
||||
DestinationID: "server",
|
||||
},
|
||||
PingTimestamp: pingTimestamp,
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
}
|
||||
@@ -1 +1,16 @@
|
||||
package socket
|
||||
|
||||
import "github.com/coder/websocket"
|
||||
|
||||
type clients = map[string]*client
|
||||
|
||||
type client struct {
|
||||
id string
|
||||
connection *websocket.Conn
|
||||
}
|
||||
|
||||
func createClient(connection *websocket.Conn) *client {
|
||||
return &client{
|
||||
connection: connection,
|
||||
}
|
||||
}
|
||||
|
||||
+77
-2
@@ -1,8 +1,83 @@
|
||||
package socket
|
||||
|
||||
type Settings struct {
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type apiSettings struct {
|
||||
}
|
||||
|
||||
func RegisterDefaultSettings(settings *Settings) error {
|
||||
func registerDefaultAPISettings(settings *apiSettings) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type settings struct {
|
||||
// Server settings
|
||||
ReadTimeout time.Duration
|
||||
WriteTimeout time.Duration
|
||||
IdleTimeout time.Duration
|
||||
ReadHeaderTimeout time.Duration
|
||||
MaxHeaderBytes int
|
||||
ShutdownTimeout time.Duration
|
||||
|
||||
// TLS configuration
|
||||
TLSConfig *tls.Config
|
||||
TLSCertFile string
|
||||
TLSKeyFile string
|
||||
|
||||
// Connection settings
|
||||
MaxConnections int
|
||||
ConnectionTimeout time.Duration
|
||||
|
||||
// WebSocket specific
|
||||
PingInterval time.Duration
|
||||
PongWait time.Duration
|
||||
WriteWait time.Duration
|
||||
MessageSizeLimit int64
|
||||
|
||||
// Router settings
|
||||
BasePath string
|
||||
EnableCORS bool
|
||||
CORSAllowOrigins []string
|
||||
CORSAllowMethods []string
|
||||
CORSAllowHeaders []string
|
||||
|
||||
// Middleware
|
||||
EnableLogging bool
|
||||
EnableCompression bool
|
||||
RateLimiter *rate.Limiter
|
||||
}
|
||||
|
||||
func (s *settings) apply(socket *Socket) {
|
||||
socket.server.ReadTimeout = s.ReadTimeout
|
||||
socket.server.WriteTimeout = s.WriteTimeout
|
||||
socket.server.IdleTimeout = s.IdleTimeout
|
||||
socket.server.ReadHeaderTimeout = s.ReadHeaderTimeout
|
||||
socket.server.MaxHeaderBytes = s.MaxHeaderBytes
|
||||
|
||||
socket.server.TLSConfig = s.TLSConfig
|
||||
|
||||
if s.EnableCORS {
|
||||
// s.applyCORS()
|
||||
}
|
||||
|
||||
if s.EnableLogging {
|
||||
|
||||
}
|
||||
|
||||
if s.EnableCompression {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func (s *settings) applyCORS(handler *http.HandlerFunc) {
|
||||
|
||||
}
|
||||
|
||||
func registerDefaultSettings(settings *settings) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
+72
-25
@@ -1,25 +1,24 @@
|
||||
package socket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
|
||||
"github.com/harshabose/skyline_sonata/serve/pkg/interceptor"
|
||||
)
|
||||
|
||||
type API struct {
|
||||
settings *Settings
|
||||
settings *apiSettings
|
||||
interceptorRegistry *interceptor.Registry
|
||||
}
|
||||
|
||||
type APIOption = func(*API) error
|
||||
|
||||
func WithSocketSettings(settings *Settings) APIOption {
|
||||
return func(api *API) error {
|
||||
api.settings = settings
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithInterceptorRegistry(registry *interceptor.Registry) APIOption {
|
||||
return func(api *API) error {
|
||||
api.interceptorRegistry = registry
|
||||
@@ -27,55 +26,103 @@ func WithInterceptorRegistry(registry *interceptor.Registry) APIOption {
|
||||
}
|
||||
}
|
||||
|
||||
func CreateSocketFactory(options ...APIOption) (*API, error) {
|
||||
func CreateAPI(options ...APIOption) (*API, error) {
|
||||
api := &API{
|
||||
settings: nil,
|
||||
settings: &apiSettings{},
|
||||
interceptorRegistry: nil,
|
||||
}
|
||||
|
||||
if err := registerDefaultAPISettings(api.settings); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, option := range options {
|
||||
if err := option(api); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if api.settings == nil {
|
||||
api.settings = &Settings{}
|
||||
if err := RegisterDefaultSettings(api.settings); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return api, nil
|
||||
}
|
||||
|
||||
func (api *API) CreateWebSocket(id string, options ...Option) (*Socket, error) {
|
||||
func (api *API) CreateWebSocket(ctx context.Context, id string, options ...Option) (*Socket, error) {
|
||||
socket := &Socket{
|
||||
id: id,
|
||||
id: id,
|
||||
settings: &settings{},
|
||||
socketAcceptOptions: &websocket.AcceptOptions{},
|
||||
ctx: ctx,
|
||||
}
|
||||
|
||||
interceptors, err := api.interceptorRegistry.Build(id)
|
||||
interceptors, err := api.interceptorRegistry.Build(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
socket.interceptor = interceptors
|
||||
|
||||
if err := registerDefaultSettings(socket.settings); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, option := range options {
|
||||
if err := option(socket); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return socket, nil
|
||||
return socket.setup(), nil
|
||||
}
|
||||
|
||||
type Socket struct {
|
||||
id string
|
||||
connections map[string]*websocket.Conn
|
||||
interceptor interceptor.Interceptor
|
||||
id string
|
||||
settings *settings
|
||||
server *http.Server
|
||||
router *http.ServeMux
|
||||
handlerFunc *http.HandlerFunc
|
||||
socketAcceptOptions *websocket.AcceptOptions
|
||||
interceptor interceptor.Interceptor
|
||||
mux sync.RWMutex
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (socket *Socket) Serve() {
|
||||
func (socket *Socket) setup() *Socket {
|
||||
socket.router = http.NewServeMux()
|
||||
socket.server = &http.Server{}
|
||||
// socket.handlerFunc = socket.wssHandler
|
||||
|
||||
socket.settings.apply(socket)
|
||||
|
||||
return socket
|
||||
}
|
||||
|
||||
func (socket *Socket) serve() error {
|
||||
defer socket.close()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-socket.ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
if err := socket.server.ListenAndServeTLS(socket.settings.TLSCertFile, socket.settings.TLSKeyFile); err != nil {
|
||||
fmt.Println(errors.New("error while serving HTTP server"))
|
||||
fmt.Println("trying again...")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (socket *Socket) baseHandler(w http.ResponseWriter, r *http.Request) {
|
||||
connection, err := websocket.Accept(w, r, socket.socketAcceptOptions)
|
||||
if err != nil {
|
||||
fmt.Println(errors.New("error while accepting socket connection"))
|
||||
}
|
||||
|
||||
if err := socket.interceptor.BindSocketConnection(connection); err != nil {
|
||||
fmt.Println("error while handling client:", err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (socket *Socket) close() {
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user