mirror of
https://github.com/harshabose/serve.git
synced 2026-04-22 23:07:27 +08:00
general commit
This commit is contained in:
@@ -1,216 +0,0 @@
|
||||
package encrypt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
|
||||
"github.com/harshabose/skyline_sonata/serve/pkg/interceptor"
|
||||
)
|
||||
|
||||
type collection struct {
|
||||
encryptor cipher.AEAD // For encrypting outgoing messages
|
||||
sessionID []byte
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
type Interceptor struct {
|
||||
interceptor.NoOpInterceptor
|
||||
collection map[*websocket.Conn]*collection
|
||||
mux sync.Mutex
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (encrypt *Interceptor) BindSocketConnection(connection *websocket.Conn) error {
|
||||
encrypt.mux.Lock()
|
||||
defer encrypt.mux.Unlock()
|
||||
|
||||
// Generate a key for this connection
|
||||
key := make([]byte, 32) // AES-256
|
||||
if _, err := io.ReadFull(rand.Reader, key); err != nil {
|
||||
return fmt.Errorf("failed to generate encryption key: %w", err)
|
||||
}
|
||||
|
||||
// Create AES cipher
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create AES cipher: %w", err)
|
||||
}
|
||||
|
||||
// Create GCM mode encryptor
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create GCM mode: %w", err)
|
||||
}
|
||||
|
||||
sessionID := make([]byte, 16)
|
||||
if _, err := io.ReadFull(rand.Reader, sessionID); err != nil {
|
||||
fmt.Println("Failed to generate new session messageID:", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(encrypt.ctx)
|
||||
encrypt.collection[connection] = &collection{
|
||||
encryptor: gcm,
|
||||
sessionID: nil,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// TODO: Exchange keys with the peer using a key exchange protocol like Diffie-Hellman
|
||||
// TODO: Store different keys for encryption and decryption
|
||||
|
||||
go encrypt.loop(connection)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (encrypt *Interceptor) BindSocketWriter(writer interceptor.Writer) interceptor.Writer {
|
||||
return interceptor.WriterFunc(func(conn *websocket.Conn, messageType websocket.MessageType, data []byte) error {
|
||||
encrypt.mux.Lock()
|
||||
collection, exists := encrypt.collection[conn]
|
||||
encrypt.mux.Unlock()
|
||||
|
||||
if !exists || collection.encryptor == nil {
|
||||
// No encryption configured for this connection yet
|
||||
// Pass through unencrypted
|
||||
return writer.Write(conn, messageType, data)
|
||||
}
|
||||
|
||||
// Generate a nonce for this message
|
||||
nonce := make([]byte, 12) // GCM typically uses a 12-byte nonce
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt the data
|
||||
encryptor := collection.encryptor
|
||||
sessionID := collection.sessionID
|
||||
encryptedData := encryptor.Seal(nil, nonce, data, sessionID)
|
||||
|
||||
// Format the encrypted message:
|
||||
// [2-byte nonce length][nonce][encrypted data]
|
||||
finalData := make([]byte, 2+len(nonce)+len(encryptedData))
|
||||
binary.BigEndian.PutUint16(finalData[:2], uint16(len(nonce)))
|
||||
copy(finalData[2:], nonce)
|
||||
copy(finalData[2+len(nonce):], encryptedData)
|
||||
|
||||
// Send the encrypted message
|
||||
return writer.Write(conn, messageType, finalData)
|
||||
})
|
||||
}
|
||||
|
||||
func (encrypt *Interceptor) BindSocketReader(reader interceptor.Reader) interceptor.Reader {
|
||||
return interceptor.ReaderFunc(func(conn *websocket.Conn) (websocket.MessageType, []byte, error) {
|
||||
// Read the encrypted message
|
||||
messageType, encryptedData, err := reader.Read(conn)
|
||||
if err != nil {
|
||||
return messageType, encryptedData, err
|
||||
}
|
||||
|
||||
encrypt.mux.Lock()
|
||||
collection, exists := encrypt.collection[conn]
|
||||
encrypt.mux.Unlock()
|
||||
|
||||
if !exists || collection.encryptor == nil || len(encryptedData) < 2 {
|
||||
// No decryption configured or data too short to be encrypted
|
||||
// Pass through as-is
|
||||
return messageType, encryptedData, nil
|
||||
}
|
||||
|
||||
// Extract nonce length
|
||||
nonceLen := binary.BigEndian.Uint16(encryptedData[:2])
|
||||
|
||||
// Ensure we have enough data for the nonce and at least some ciphertext
|
||||
if len(encryptedData) < int(2+nonceLen) {
|
||||
return messageType, encryptedData, fmt.Errorf("encrypted data too short")
|
||||
}
|
||||
|
||||
// Extract nonce and ciphertext
|
||||
nonce := encryptedData[2 : 2+nonceLen]
|
||||
ciphertext := encryptedData[2+nonceLen:]
|
||||
|
||||
// Decrypt the data
|
||||
decryptor := collection.encryptor
|
||||
sessionID := collection.sessionID
|
||||
plaintext, err := decryptor.Open(nil, nonce, ciphertext, sessionID)
|
||||
if err != nil {
|
||||
return messageType, encryptedData, fmt.Errorf("decryption failed: %w", err)
|
||||
}
|
||||
|
||||
return messageType, plaintext, nil
|
||||
})
|
||||
}
|
||||
|
||||
func (encrypt *Interceptor) UnBindSocketConnection(connection *websocket.Conn) {
|
||||
encrypt.mux.Lock()
|
||||
defer encrypt.mux.Unlock()
|
||||
|
||||
collection, exists := encrypt.collection[connection]
|
||||
if !exists {
|
||||
fmt.Println("connection does not exists")
|
||||
return
|
||||
}
|
||||
|
||||
collection.cancel()
|
||||
delete(encrypt.collection, connection)
|
||||
}
|
||||
|
||||
func (encrypt *Interceptor) UnBindSocketWriter(_ interceptor.Writer) {}
|
||||
|
||||
func (encrypt *Interceptor) UnBindSocketReader(_ interceptor.Reader) {}
|
||||
|
||||
func (encrypt *Interceptor) Close() error {
|
||||
encrypt.mux.Lock()
|
||||
defer encrypt.mux.Unlock()
|
||||
|
||||
encrypt.collection = make(map[*websocket.Conn]*collection)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (encrypt *Interceptor) loop(connection *websocket.Conn) {
|
||||
|
||||
encrypt.mux.Lock()
|
||||
collection, exists := encrypt.collection[connection]
|
||||
if !exists {
|
||||
fmt.Println("connection does not exists")
|
||||
return
|
||||
}
|
||||
ctx := collection.ctx
|
||||
encrypt.mux.Unlock()
|
||||
|
||||
timer := time.NewTicker(5 * time.Minute)
|
||||
defer timer.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timer.C:
|
||||
encrypt.mux.Lock()
|
||||
|
||||
newSessionID := make([]byte, 16)
|
||||
if _, err := io.ReadFull(rand.Reader, newSessionID); err != nil {
|
||||
fmt.Println("Failed to generate new session messageID:", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if collection, exists := encrypt.collection[connection]; exists {
|
||||
collection.sessionID = nil // Keep nil until sending to peer mechanism is set
|
||||
}
|
||||
|
||||
// send the update sessionID to peer
|
||||
|
||||
encrypt.mux.Unlock()
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4,40 +4,38 @@ import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/harshabose/skyline_sonata/serve/pkg/interceptor"
|
||||
"github.com/harshabose/skyline_sonata/serve/pkg/message"
|
||||
)
|
||||
|
||||
type encryptor interface {
|
||||
Encrypt(message.Message) (*Message, error)
|
||||
Decrypt(*Message) (message.Message, error)
|
||||
Encrypt(message.Message) (*interceptor.BaseMessage, error)
|
||||
Decrypt(*interceptor.BaseMessage) (message.Message, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
type aes256 struct {
|
||||
key []byte
|
||||
encryptor cipher.AEAD
|
||||
sessionID []byte
|
||||
mux sync.RWMutex
|
||||
}
|
||||
|
||||
func createAES256() (*aes256, error) {
|
||||
// generate key
|
||||
key := make([]byte, 32)
|
||||
if _, err := io.ReadFull(rand.Reader, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// create cipher block
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// create GCM mode encryptor
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -49,12 +47,13 @@ func createAES256() (*aes256, error) {
|
||||
}
|
||||
|
||||
return &aes256{
|
||||
key: key,
|
||||
encryptor: gcm,
|
||||
sessionID: sessionID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *aes256) Encrypt(message message.Message) (*Message, error) {
|
||||
func (a *aes256) Encrypt(senderID, receiverID string, message message.Message) (*interceptor.BaseMessage, error) {
|
||||
nonce := make([]byte, 12)
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return nil, err
|
||||
@@ -69,19 +68,28 @@ func (a *aes256) Encrypt(message message.Message) (*Message, error) {
|
||||
defer a.mux.Unlock()
|
||||
|
||||
encryptedData := a.encryptor.Seal(nil, nonce, data, a.sessionID)
|
||||
finalData := make([]byte, 2+len(nonce)+len(encryptedData))
|
||||
binary.BigEndian.PutUint16(finalData[:2], uint16(len(nonce)))
|
||||
payload := &Encrypted{Data: encryptedData, Nonce: nonce, Timestamp: time.Now()}
|
||||
|
||||
copy(finalData[2:], nonce)
|
||||
copy(finalData[2+len(nonce):], encryptedData)
|
||||
|
||||
return CreateMessage(PayloadEncryptedType, finalData), nil
|
||||
return CreateMessage(senderID, receiverID, payload)
|
||||
}
|
||||
|
||||
func (a *aes256) Decrypt(message *Message) (message.Message, error) {
|
||||
func (a *aes256) Decrypt(message *Encrypted) (message.Message, error) {
|
||||
_, err := a.encryptor.Open(nil, message.Nonce, message.Data, a.sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO: figure out how to create a message here
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *aes256) Close() error {
|
||||
a.mux.Lock()
|
||||
defer a.mux.Unlock()
|
||||
|
||||
a.sessionID = nil
|
||||
a.encryptor = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,7 +1,185 @@
|
||||
package encrypt
|
||||
|
||||
import "github.com/harshabose/skyline_sonata/serve/pkg/interceptor"
|
||||
import (
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
|
||||
"github.com/harshabose/skyline_sonata/serve/pkg/interceptor"
|
||||
"github.com/harshabose/skyline_sonata/serve/pkg/message"
|
||||
)
|
||||
|
||||
type Interceptor struct {
|
||||
interceptor.NoOpInterceptor
|
||||
states map[interceptor.Connection]*state
|
||||
mux sync.Mutex
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (encrypt *Interceptor) BindSocketConnection(connection *websocket.Conn) error {
|
||||
encrypt.mux.Lock()
|
||||
defer encrypt.mux.Unlock()
|
||||
|
||||
// TODO: Exchange keys with the peer using a key exchange protocol like Diffie-Hellman
|
||||
// TODO: Store different keys for encryption and decryption
|
||||
|
||||
go encrypt.loop(connection)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (encrypt *Interceptor) BindSocketWriter(writer interceptor.Writer) interceptor.Writer {
|
||||
return interceptor.WriterFunc(func(connection interceptor.Connection, messageType websocket.MessageType, message message.Message) error {
|
||||
state, exists := encrypt.states[connection]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
})
|
||||
return interceptor.WriterFunc(func(conn *websocket.Conn, messageType websocket.MessageType, data []byte) error {
|
||||
encrypt.mux.Lock()
|
||||
collection, exists := encrypt.states[conn]
|
||||
encrypt.mux.Unlock()
|
||||
|
||||
if !exists || collection.encryptor == nil {
|
||||
// No encryption configured for this connection yet
|
||||
// Pass through unencrypted
|
||||
return writer.Write(conn, messageType, data)
|
||||
}
|
||||
|
||||
// Generate a nonce for this message
|
||||
nonce := make([]byte, 12) // GCM typically uses a 12-byte nonce
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt the data
|
||||
encryptor := collection.encryptor
|
||||
sessionID := collection.sessionID
|
||||
encryptedData := encryptor.Seal(nil, nonce, data, sessionID)
|
||||
|
||||
// Format the encrypted message:
|
||||
// [2-byte nonce length][nonce][encrypted data]
|
||||
finalData := make([]byte, 2+len(nonce)+len(encryptedData))
|
||||
binary.BigEndian.PutUint16(finalData[:2], uint16(len(nonce)))
|
||||
copy(finalData[2:], nonce)
|
||||
copy(finalData[2+len(nonce):], encryptedData)
|
||||
|
||||
// Send the encrypted message
|
||||
return writer.Write(conn, messageType, finalData)
|
||||
})
|
||||
}
|
||||
|
||||
func (encrypt *Interceptor) BindSocketReader(reader interceptor.Reader) interceptor.Reader {
|
||||
return interceptor.ReaderFunc(func(conn *websocket.Conn) (websocket.MessageType, []byte, error) {
|
||||
// Read the encrypted message
|
||||
messageType, encryptedData, err := reader.Read(conn)
|
||||
if err != nil {
|
||||
return messageType, encryptedData, err
|
||||
}
|
||||
|
||||
encrypt.mux.Lock()
|
||||
collection, exists := encrypt.states[conn]
|
||||
encrypt.mux.Unlock()
|
||||
|
||||
if !exists || collection.encryptor == nil || len(encryptedData) < 2 {
|
||||
// No decryption configured or data too short to be encrypted
|
||||
// Pass through as-is
|
||||
return messageType, encryptedData, nil
|
||||
}
|
||||
|
||||
// Extract nonce length
|
||||
nonceLen := binary.BigEndian.Uint16(encryptedData[:2])
|
||||
|
||||
// Ensure we have enough data for the nonce and at least some ciphertext
|
||||
if len(encryptedData) < int(2+nonceLen) {
|
||||
return messageType, encryptedData, fmt.Errorf("encrypted data too short")
|
||||
}
|
||||
|
||||
// Extract nonce and ciphertext
|
||||
nonce := encryptedData[2 : 2+nonceLen]
|
||||
ciphertext := encryptedData[2+nonceLen:]
|
||||
|
||||
// Decrypt the data
|
||||
decryptor := collection.encryptor
|
||||
sessionID := collection.sessionID
|
||||
plaintext, err := decryptor.Open(nil, nonce, ciphertext, sessionID)
|
||||
if err != nil {
|
||||
return messageType, encryptedData, fmt.Errorf("decryption failed: %w", err)
|
||||
}
|
||||
|
||||
return messageType, plaintext, nil
|
||||
})
|
||||
}
|
||||
|
||||
func (encrypt *Interceptor) UnBindSocketConnection(connection *websocket.Conn) {
|
||||
encrypt.mux.Lock()
|
||||
defer encrypt.mux.Unlock()
|
||||
|
||||
collection, exists := encrypt.states[connection]
|
||||
if !exists {
|
||||
fmt.Println("connection does not exists")
|
||||
return
|
||||
}
|
||||
|
||||
collection.cancel()
|
||||
delete(encrypt.states, connection)
|
||||
}
|
||||
|
||||
func (encrypt *Interceptor) UnBindSocketWriter(_ interceptor.Writer) {}
|
||||
|
||||
func (encrypt *Interceptor) UnBindSocketReader(_ interceptor.Reader) {}
|
||||
|
||||
func (encrypt *Interceptor) Close() error {
|
||||
encrypt.mux.Lock()
|
||||
defer encrypt.mux.Unlock()
|
||||
|
||||
encrypt.states = make(map[*websocket.Conn]*collection)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (encrypt *Interceptor) loop(connection *websocket.Conn) {
|
||||
|
||||
encrypt.mux.Lock()
|
||||
collection, exists := encrypt.states[connection]
|
||||
if !exists {
|
||||
fmt.Println("connection does not exists")
|
||||
return
|
||||
}
|
||||
ctx := collection.ctx
|
||||
encrypt.mux.Unlock()
|
||||
|
||||
timer := time.NewTicker(5 * time.Minute)
|
||||
defer timer.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timer.C:
|
||||
encrypt.mux.Lock()
|
||||
|
||||
newSessionID := make([]byte, 16)
|
||||
if _, err := io.ReadFull(rand.Reader, newSessionID); err != nil {
|
||||
fmt.Println("Failed to generate new session messageID:", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if collection, exists := encrypt.states[connection]; exists {
|
||||
collection.sessionID = nil // Keep nil until sending to peer mechanism is set
|
||||
}
|
||||
|
||||
// send the update sessionID to peer
|
||||
|
||||
encrypt.mux.Unlock()
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,22 +2,69 @@ package encrypt
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/harshabose/skyline_sonata/serve/pkg/interceptor"
|
||||
"github.com/harshabose/skyline_sonata/serve/pkg/message"
|
||||
)
|
||||
|
||||
type PayloadType string
|
||||
var (
|
||||
MainType interceptor.MainType = "encrypt"
|
||||
|
||||
const (
|
||||
PayloadEncryptedType PayloadType = "encrypt:encrypted"
|
||||
EncryptedSubType interceptor.SubType = "encrypted"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
Type PayloadType `json:"type"`
|
||||
Payload json.RawMessage `json:"payload"`
|
||||
}
|
||||
|
||||
func CreateMessage(_type PayloadType, payload json.RawMessage) *Message {
|
||||
return &Message{
|
||||
Type: _type,
|
||||
Payload: payload,
|
||||
func CreateMessage(senderID, receiverID string, payload interceptor.Payload) (*interceptor.BaseMessage, error) {
|
||||
data, err := payload.Marshal()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &interceptor.BaseMessage{
|
||||
BaseMessage: message.BaseMessage{
|
||||
Header: message.Header{
|
||||
SenderID: senderID,
|
||||
ReceiverID: receiverID,
|
||||
Protocol: interceptor.IProtocol,
|
||||
},
|
||||
Payload: data,
|
||||
},
|
||||
Header: interceptor.Header{
|
||||
|
||||
MainType: MainType,
|
||||
SubType: payload.Type(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
type Encrypted struct {
|
||||
Data []byte `json:"data"`
|
||||
Nonce []byte `json:"nonce"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
|
||||
func (payload *Encrypted) Marshal() ([]byte, error) {
|
||||
return json.Marshal(payload)
|
||||
}
|
||||
|
||||
func (payload *Encrypted) Unmarshal(data []byte) error {
|
||||
return json.Unmarshal(data, payload)
|
||||
}
|
||||
|
||||
func (payload *Encrypted) Validate() error {
|
||||
if payload.Data == nil || payload.Nonce == nil || len(payload.Data) <= 0 || len(payload.Nonce) <= 0 {
|
||||
return errors.New("not valid")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (payload *Encrypted) Process(header interceptor.Header, i interceptor.Interceptor, connection interceptor.Connection) error {
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (payload *Encrypted) Type() interceptor.SubType {
|
||||
return EncryptedSubType
|
||||
}
|
||||
|
||||
@@ -1 +1,10 @@
|
||||
package encrypt
|
||||
|
||||
import "context"
|
||||
|
||||
type state struct {
|
||||
id string
|
||||
encryptor encryptor
|
||||
cancel context.CancelFunc
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
|
||||
"github.com/harshabose/skyline_sonata/serve/pkg/message"
|
||||
)
|
||||
|
||||
// Registry maintains a collection of interceptor factories that can be used to
|
||||
@@ -110,29 +112,29 @@ type Writer interface {
|
||||
// Write sends a message to the connection
|
||||
// Takes the connection, message type, and message to write
|
||||
// Returns any error encountered during writing
|
||||
Write(conn Connection, messageType websocket.MessageType, message Message) error
|
||||
Write(conn Connection, messageType websocket.MessageType, message message.Message) error
|
||||
}
|
||||
|
||||
// Reader is an interface for reading messages from a websocket connection
|
||||
type Reader interface {
|
||||
// Read reads a message from the connection
|
||||
// Returns the message type, message data, and any error
|
||||
Read(conn Connection) (messageType websocket.MessageType, message Message, err error)
|
||||
Read(conn Connection) (messageType websocket.MessageType, message message.Message, err error)
|
||||
}
|
||||
|
||||
// ReaderFunc is a function type that implements the Reader interface
|
||||
type ReaderFunc func(conn Connection) (messageType websocket.MessageType, message Message, err error)
|
||||
type ReaderFunc func(conn Connection) (messageType websocket.MessageType, message message.Message, err error)
|
||||
|
||||
// Read implements the Reader interface for ReaderFunc
|
||||
func (f ReaderFunc) Read(conn Connection) (messageType websocket.MessageType, message Message, err error) {
|
||||
func (f ReaderFunc) Read(conn Connection) (messageType websocket.MessageType, message message.Message, err error) {
|
||||
return f(conn)
|
||||
}
|
||||
|
||||
// WriterFunc is a function type that implements the Writer interface
|
||||
type WriterFunc func(conn Connection, messageType websocket.MessageType, message Message) error
|
||||
type WriterFunc func(conn Connection, messageType websocket.MessageType, message message.Message) error
|
||||
|
||||
// Write implements the Writer interface for WriterFunc
|
||||
func (f WriterFunc) Write(conn Connection, messageType websocket.MessageType, message Message) error {
|
||||
func (f WriterFunc) Write(conn Connection, messageType websocket.MessageType, message message.Message) error {
|
||||
return f(conn, messageType, message)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,37 +1,22 @@
|
||||
package interceptor
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type Message interface {
|
||||
Marshal() ([]byte, error)
|
||||
Unmarshal([]byte) error
|
||||
}
|
||||
import (
|
||||
"github.com/harshabose/skyline_sonata/serve/pkg/message"
|
||||
)
|
||||
|
||||
type (
|
||||
Protocol string
|
||||
SubType string
|
||||
MainType string
|
||||
)
|
||||
|
||||
type Header struct {
|
||||
SenderID string `json:"source_id"`
|
||||
ReceiverID string `json:"destination_id"`
|
||||
Protocol Protocol `json:"protocol"`
|
||||
MainType MainType `json:"main_type"`
|
||||
SubType SubType `json:"sub_type"`
|
||||
MainType MainType `json:"main_type"`
|
||||
SubType SubType `json:"sub_type"`
|
||||
}
|
||||
|
||||
var IProtocol Protocol = "interceptor"
|
||||
var IProtocol message.Protocol = "interceptor"
|
||||
|
||||
type BaseMessage struct { // This actually needs to be interceptor module and Message interface should be in its own module
|
||||
type BaseMessage struct {
|
||||
Header
|
||||
Payload json.RawMessage `json:"payload"`
|
||||
}
|
||||
|
||||
func (msg *BaseMessage) Marshal() ([]byte, error) {
|
||||
return json.Marshal(msg)
|
||||
}
|
||||
|
||||
func (msg *BaseMessage) Unmarshal(data []byte) error {
|
||||
return json.Unmarshal(data, msg)
|
||||
message.BaseMessage
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/harshabose/skyline_sonata/serve/pkg/interceptor"
|
||||
"github.com/harshabose/skyline_sonata/serve/pkg/message"
|
||||
)
|
||||
|
||||
type Interceptor struct {
|
||||
@@ -47,17 +48,17 @@ func (i *Interceptor) BindSocketConnection(connection interceptor.Connection, wr
|
||||
}
|
||||
|
||||
func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) interceptor.Writer {
|
||||
return interceptor.WriterFunc(func(conn interceptor.Connection, messageType websocket.MessageType, message interceptor.Message) error {
|
||||
return interceptor.WriterFunc(func(conn interceptor.Connection, messageType websocket.MessageType, message message.Message) error {
|
||||
i.Mutex.Lock()
|
||||
defer i.Mutex.Unlock()
|
||||
|
||||
msg, ok := message.(*Message)
|
||||
if !ok || (msg.MainType != "ping" && msg.SubType != "ping") {
|
||||
msg, ok := message.(*interceptor.BaseMessage)
|
||||
if !ok || (msg.Protocol != interceptor.IProtocol && msg.MainType != MainType) {
|
||||
return writer.Write(conn, messageType, message)
|
||||
}
|
||||
|
||||
payload := &Ping{}
|
||||
if err := payload.Unmarshal(msg.Payload); err != nil {
|
||||
payload, err := PayloadUnmarshal(msg.SubType, msg.Payload)
|
||||
if err != nil {
|
||||
return writer.Write(conn, messageType, message)
|
||||
}
|
||||
|
||||
@@ -72,7 +73,7 @@ func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) intercept
|
||||
}
|
||||
|
||||
func (i *Interceptor) InterceptSocketReader(reader interceptor.Reader) interceptor.Reader {
|
||||
return interceptor.ReaderFunc(func(conn interceptor.Connection) (messageType websocket.MessageType, message interceptor.Message, err error) {
|
||||
return interceptor.ReaderFunc(func(conn interceptor.Connection) (messageType websocket.MessageType, message message.Message, err error) {
|
||||
messageType, message, err = reader.Read(conn)
|
||||
if err != nil {
|
||||
return messageType, message, err
|
||||
@@ -81,14 +82,14 @@ func (i *Interceptor) InterceptSocketReader(reader interceptor.Reader) intercept
|
||||
i.Mutex.Lock()
|
||||
defer i.Mutex.Unlock()
|
||||
|
||||
msg, ok := message.(*Message)
|
||||
if !ok {
|
||||
msg, ok := message.(*interceptor.BaseMessage)
|
||||
if !ok || (msg.Protocol != interceptor.IProtocol && msg.MainType != MainType) {
|
||||
return messageType, message, nil
|
||||
}
|
||||
|
||||
payload := &Pong{}
|
||||
if err := payload.Unmarshal(msg.Payload); err != nil {
|
||||
return messageType, message, nil
|
||||
payload, err := PayloadUnmarshal(msg.SubType, msg.Payload)
|
||||
if err != nil {
|
||||
return messageType, message, err
|
||||
}
|
||||
|
||||
if _, exists := i.states[conn]; exists {
|
||||
@@ -161,7 +162,7 @@ func (i *Interceptor) loop(ctx context.Context, interval time.Duration, connecti
|
||||
}
|
||||
}
|
||||
|
||||
func (payload *Ping) Process(header interceptor.Header, interceptor interceptor.Interceptor, connection interceptor.Connection) error {
|
||||
func (payload *Ping) Process(_ interceptor.Header, interceptor interceptor.Interceptor, connection interceptor.Connection) error {
|
||||
if err := payload.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -2,29 +2,48 @@ package ping
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/harshabose/skyline_sonata/serve/pkg/interceptor"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
interceptor.BaseMessage
|
||||
var (
|
||||
MainType interceptor.MainType = "ping-pong"
|
||||
PingSubType interceptor.SubType = "ping"
|
||||
PongSubType interceptor.SubType = "pong"
|
||||
|
||||
subTypeMap = map[interceptor.SubType]interceptor.Payload{
|
||||
PingSubType: &Ping{},
|
||||
PongSubType: &Pong{},
|
||||
}
|
||||
)
|
||||
|
||||
func PayloadUnmarshal(sub interceptor.SubType, p json.RawMessage) (interceptor.Payload, error) {
|
||||
if payload, exists := subTypeMap[sub]; exists {
|
||||
if err := payload.Unmarshal(p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("processor does not exist for given type")
|
||||
}
|
||||
|
||||
func CreateMessage(senderID, receiverID string, payload interceptor.Payload) (*Message, error) {
|
||||
func CreateMessage(senderID, receiverID string, payload interceptor.Payload) (*interceptor.BaseMessage, error) {
|
||||
data, err := payload.Marshal()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Message{
|
||||
BaseMessage: interceptor.BaseMessage{
|
||||
Header: interceptor.Header{
|
||||
SenderID: senderID,
|
||||
ReceiverID: receiverID,
|
||||
},
|
||||
Payload: data,
|
||||
return &interceptor.BaseMessage{
|
||||
Header: interceptor.Header{
|
||||
SenderID: senderID,
|
||||
ReceiverID: receiverID,
|
||||
MainType: MainType,
|
||||
SubType: payload.Type(),
|
||||
},
|
||||
Payload: data,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -69,6 +88,10 @@ func (payload *Ping) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (payload *Ping) Type() interceptor.SubType {
|
||||
return PingSubType
|
||||
}
|
||||
|
||||
// Pong represents a response to a ping message, confirming connection health.
|
||||
// It contains the original ping's message ID and timestamp, plus its own timestamp,
|
||||
// allowing the server to calculate the round-trip time.
|
||||
@@ -110,3 +133,7 @@ func (payload *Pong) Unmarshal(data []byte) error {
|
||||
func (payload *Pong) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (payload *Pong) Type() interceptor.SubType {
|
||||
return PongSubType
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/coder/websocket"
|
||||
|
||||
"github.com/harshabose/skyline_sonata/serve/pkg/interceptor"
|
||||
"github.com/harshabose/skyline_sonata/serve/pkg/message"
|
||||
)
|
||||
|
||||
type Interceptor struct {
|
||||
@@ -43,17 +44,17 @@ func (i *Interceptor) BindSocketConnection(connection interceptor.Connection, wr
|
||||
}
|
||||
|
||||
func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) interceptor.Writer {
|
||||
return interceptor.WriterFunc(func(conn interceptor.Connection, messageType websocket.MessageType, message interceptor.Message) error {
|
||||
return interceptor.WriterFunc(func(conn interceptor.Connection, messageType websocket.MessageType, message message.Message) error {
|
||||
i.Mutex.Lock()
|
||||
defer i.Mutex.Unlock()
|
||||
|
||||
msg, ok := message.(*interceptor.BaseMessage)
|
||||
if !ok || (msg.Header.Protocol != interceptor.IProtocol && msg.Header.MainType != MainType) {
|
||||
if !ok || (msg.Protocol != interceptor.IProtocol && msg.Header.MainType != MainType) {
|
||||
return writer.Write(conn, messageType, message)
|
||||
}
|
||||
|
||||
payload := &Pong{}
|
||||
if err := payload.Unmarshal(msg.Payload); err != nil {
|
||||
payload, err := PayloadUnmarshal(msg.SubType, msg.Payload)
|
||||
if err != nil {
|
||||
return writer.Write(conn, messageType, message)
|
||||
}
|
||||
|
||||
@@ -68,7 +69,7 @@ func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) intercept
|
||||
}
|
||||
|
||||
func (i *Interceptor) InterceptSocketReader(reader interceptor.Reader) interceptor.Reader {
|
||||
return interceptor.ReaderFunc(func(conn interceptor.Connection) (messageType websocket.MessageType, message interceptor.Message, err error) {
|
||||
return interceptor.ReaderFunc(func(conn interceptor.Connection) (messageType websocket.MessageType, message message.Message, err error) {
|
||||
messageType, message, err = reader.Read(conn)
|
||||
if err != nil {
|
||||
return messageType, message, err
|
||||
@@ -78,13 +79,13 @@ func (i *Interceptor) InterceptSocketReader(reader interceptor.Reader) intercept
|
||||
defer i.Mutex.Unlock()
|
||||
|
||||
msg, ok := message.(*interceptor.BaseMessage)
|
||||
if !ok || (msg.Header.Protocol != interceptor.IProtocol && msg.Header.MainType != MainType) {
|
||||
if !ok || (msg.Protocol != interceptor.IProtocol && msg.Header.MainType != MainType) {
|
||||
return messageType, message, nil
|
||||
}
|
||||
|
||||
payload := &Ping{}
|
||||
if err := payload.Unmarshal(msg.Payload); err != nil {
|
||||
return messageType, message, nil
|
||||
payload, err := PayloadUnmarshal(msg.SubType, msg.Payload)
|
||||
if err != nil {
|
||||
return messageType, message, err
|
||||
}
|
||||
|
||||
if _, exists := i.states[conn]; exists {
|
||||
|
||||
@@ -2,12 +2,34 @@ package pong
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/harshabose/skyline_sonata/serve/pkg/interceptor"
|
||||
)
|
||||
|
||||
var MainType interceptor.MainType = "pong"
|
||||
var (
|
||||
MainType interceptor.MainType = "ping-pong"
|
||||
|
||||
PingSubType interceptor.SubType = "ping"
|
||||
PongSubType interceptor.SubType = "pong"
|
||||
|
||||
subTypeMap = map[interceptor.SubType]interceptor.Payload{
|
||||
PingSubType: &Ping{},
|
||||
PongSubType: &Pong{},
|
||||
}
|
||||
)
|
||||
|
||||
func PayloadUnmarshal(sub interceptor.SubType, p json.RawMessage) (interceptor.Payload, error) {
|
||||
if payload, exists := subTypeMap[sub]; exists {
|
||||
if err := payload.Unmarshal(p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("processor does not exist for given type")
|
||||
}
|
||||
|
||||
func CreateMessage(senderID, receiverID string, payload interceptor.Payload) (*interceptor.BaseMessage, error) {
|
||||
data, err := payload.Marshal()
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/coder/websocket"
|
||||
|
||||
"github.com/harshabose/skyline_sonata/serve/pkg/interceptor"
|
||||
"github.com/harshabose/skyline_sonata/serve/pkg/message"
|
||||
)
|
||||
|
||||
type Interceptor struct {
|
||||
@@ -30,7 +31,7 @@ func (i *Interceptor) BindSocketConnection(connection interceptor.Connection, wr
|
||||
}
|
||||
|
||||
func (i *Interceptor) InterceptSocketReader(reader interceptor.Reader) interceptor.Reader {
|
||||
return interceptor.ReaderFunc(func(connection interceptor.Connection) (websocket.MessageType, interceptor.Message, error) {
|
||||
return interceptor.ReaderFunc(func(connection interceptor.Connection) (websocket.MessageType, message.Message, error) {
|
||||
messageType, data, err := reader.Read(connection)
|
||||
if err != nil {
|
||||
return messageType, data, err
|
||||
@@ -44,15 +45,14 @@ func (i *Interceptor) InterceptSocketReader(reader interceptor.Reader) intercept
|
||||
i.Mutex.Lock()
|
||||
defer i.Mutex.Unlock()
|
||||
|
||||
payload, err := PayloadUnmarshal(msg.SubType, msg.Payload)
|
||||
if err != nil {
|
||||
return messageType, data, nil
|
||||
}
|
||||
|
||||
if _, exists := i.states[connection]; exists {
|
||||
payload, err := PayloadUnmarshal(msg.SubType, msg.Payload)
|
||||
if err != nil {
|
||||
fmt.Println("error while processing room message: ", err.Error())
|
||||
return messageType, data, nil
|
||||
}
|
||||
|
||||
if err := payload.Process(msg.Header, i, connection); err != nil {
|
||||
fmt.Println("error while processing room message: ", err.Error())
|
||||
return messageType, data, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,29 @@
|
||||
package message
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type Protocol string
|
||||
|
||||
type Message interface {
|
||||
Marshal() ([]byte, error)
|
||||
Unmarshal([]byte) error
|
||||
}
|
||||
|
||||
type Header struct {
|
||||
SenderID string `json:"source_id"`
|
||||
ReceiverID string `json:"destination_id"`
|
||||
Protocol Protocol `json:"protocol"`
|
||||
}
|
||||
|
||||
type BaseMessage struct {
|
||||
Header
|
||||
Payload json.RawMessage `json:"payload,omitempty"`
|
||||
}
|
||||
|
||||
func (msg *BaseMessage) Marshal() ([]byte, error) {
|
||||
return json.Marshal(msg)
|
||||
}
|
||||
|
||||
func (msg *BaseMessage) Unmarshal(data []byte) error {
|
||||
return json.Unmarshal(data, msg)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user