general commit

This commit is contained in:
harshabose
2025-04-02 16:52:05 +05:30
parent f8d96fe32b
commit 9f9240cd88
13 changed files with 399 additions and 312 deletions
-216
View File
@@ -1,216 +0,0 @@
package encrypt
import (
"context"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/binary"
"fmt"
"io"
"sync"
"time"
"github.com/coder/websocket"
"github.com/harshabose/skyline_sonata/serve/pkg/interceptor"
)
type collection struct {
encryptor cipher.AEAD // For encrypting outgoing messages
sessionID []byte
ctx context.Context
cancel context.CancelFunc
}
type Interceptor struct {
interceptor.NoOpInterceptor
collection map[*websocket.Conn]*collection
mux sync.Mutex
ctx context.Context
}
func (encrypt *Interceptor) BindSocketConnection(connection *websocket.Conn) error {
encrypt.mux.Lock()
defer encrypt.mux.Unlock()
// Generate a key for this connection
key := make([]byte, 32) // AES-256
if _, err := io.ReadFull(rand.Reader, key); err != nil {
return fmt.Errorf("failed to generate encryption key: %w", err)
}
// Create AES cipher
block, err := aes.NewCipher(key)
if err != nil {
return fmt.Errorf("failed to create AES cipher: %w", err)
}
// Create GCM mode encryptor
gcm, err := cipher.NewGCM(block)
if err != nil {
return fmt.Errorf("failed to create GCM mode: %w", err)
}
sessionID := make([]byte, 16)
if _, err := io.ReadFull(rand.Reader, sessionID); err != nil {
fmt.Println("Failed to generate new session messageID:", err)
}
ctx, cancel := context.WithCancel(encrypt.ctx)
encrypt.collection[connection] = &collection{
encryptor: gcm,
sessionID: nil,
ctx: ctx,
cancel: cancel,
}
// TODO: Exchange keys with the peer using a key exchange protocol like Diffie-Hellman
// TODO: Store different keys for encryption and decryption
go encrypt.loop(connection)
return nil
}
func (encrypt *Interceptor) BindSocketWriter(writer interceptor.Writer) interceptor.Writer {
return interceptor.WriterFunc(func(conn *websocket.Conn, messageType websocket.MessageType, data []byte) error {
encrypt.mux.Lock()
collection, exists := encrypt.collection[conn]
encrypt.mux.Unlock()
if !exists || collection.encryptor == nil {
// No encryption configured for this connection yet
// Pass through unencrypted
return writer.Write(conn, messageType, data)
}
// Generate a nonce for this message
nonce := make([]byte, 12) // GCM typically uses a 12-byte nonce
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return fmt.Errorf("failed to generate nonce: %w", err)
}
// Encrypt the data
encryptor := collection.encryptor
sessionID := collection.sessionID
encryptedData := encryptor.Seal(nil, nonce, data, sessionID)
// Format the encrypted message:
// [2-byte nonce length][nonce][encrypted data]
finalData := make([]byte, 2+len(nonce)+len(encryptedData))
binary.BigEndian.PutUint16(finalData[:2], uint16(len(nonce)))
copy(finalData[2:], nonce)
copy(finalData[2+len(nonce):], encryptedData)
// Send the encrypted message
return writer.Write(conn, messageType, finalData)
})
}
func (encrypt *Interceptor) BindSocketReader(reader interceptor.Reader) interceptor.Reader {
return interceptor.ReaderFunc(func(conn *websocket.Conn) (websocket.MessageType, []byte, error) {
// Read the encrypted message
messageType, encryptedData, err := reader.Read(conn)
if err != nil {
return messageType, encryptedData, err
}
encrypt.mux.Lock()
collection, exists := encrypt.collection[conn]
encrypt.mux.Unlock()
if !exists || collection.encryptor == nil || len(encryptedData) < 2 {
// No decryption configured or data too short to be encrypted
// Pass through as-is
return messageType, encryptedData, nil
}
// Extract nonce length
nonceLen := binary.BigEndian.Uint16(encryptedData[:2])
// Ensure we have enough data for the nonce and at least some ciphertext
if len(encryptedData) < int(2+nonceLen) {
return messageType, encryptedData, fmt.Errorf("encrypted data too short")
}
// Extract nonce and ciphertext
nonce := encryptedData[2 : 2+nonceLen]
ciphertext := encryptedData[2+nonceLen:]
// Decrypt the data
decryptor := collection.encryptor
sessionID := collection.sessionID
plaintext, err := decryptor.Open(nil, nonce, ciphertext, sessionID)
if err != nil {
return messageType, encryptedData, fmt.Errorf("decryption failed: %w", err)
}
return messageType, plaintext, nil
})
}
func (encrypt *Interceptor) UnBindSocketConnection(connection *websocket.Conn) {
encrypt.mux.Lock()
defer encrypt.mux.Unlock()
collection, exists := encrypt.collection[connection]
if !exists {
fmt.Println("connection does not exists")
return
}
collection.cancel()
delete(encrypt.collection, connection)
}
func (encrypt *Interceptor) UnBindSocketWriter(_ interceptor.Writer) {}
func (encrypt *Interceptor) UnBindSocketReader(_ interceptor.Reader) {}
func (encrypt *Interceptor) Close() error {
encrypt.mux.Lock()
defer encrypt.mux.Unlock()
encrypt.collection = make(map[*websocket.Conn]*collection)
return nil
}
func (encrypt *Interceptor) loop(connection *websocket.Conn) {
encrypt.mux.Lock()
collection, exists := encrypt.collection[connection]
if !exists {
fmt.Println("connection does not exists")
return
}
ctx := collection.ctx
encrypt.mux.Unlock()
timer := time.NewTicker(5 * time.Minute)
defer timer.Stop()
for {
select {
case <-timer.C:
encrypt.mux.Lock()
newSessionID := make([]byte, 16)
if _, err := io.ReadFull(rand.Reader, newSessionID); err != nil {
fmt.Println("Failed to generate new session messageID:", err)
continue
}
if collection, exists := encrypt.collection[connection]; exists {
collection.sessionID = nil // Keep nil until sending to peer mechanism is set
}
// send the update sessionID to peer
encrypt.mux.Unlock()
case <-ctx.Done():
return
}
}
}
+23 -15
View File
@@ -4,40 +4,38 @@ import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/binary"
"errors"
"io"
"sync"
"time"
"github.com/harshabose/skyline_sonata/serve/pkg/interceptor"
"github.com/harshabose/skyline_sonata/serve/pkg/message"
)
type encryptor interface {
Encrypt(message.Message) (*Message, error)
Decrypt(*Message) (message.Message, error)
Encrypt(message.Message) (*interceptor.BaseMessage, error)
Decrypt(*interceptor.BaseMessage) (message.Message, error)
Close() error
}
type aes256 struct {
key []byte
encryptor cipher.AEAD
sessionID []byte
mux sync.RWMutex
}
func createAES256() (*aes256, error) {
// generate key
key := make([]byte, 32)
if _, err := io.ReadFull(rand.Reader, key); err != nil {
return nil, err
}
// create cipher block
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
// create GCM mode encryptor
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
@@ -49,12 +47,13 @@ func createAES256() (*aes256, error) {
}
return &aes256{
key: key,
encryptor: gcm,
sessionID: sessionID,
}, nil
}
func (a *aes256) Encrypt(message message.Message) (*Message, error) {
func (a *aes256) Encrypt(senderID, receiverID string, message message.Message) (*interceptor.BaseMessage, error) {
nonce := make([]byte, 12)
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return nil, err
@@ -69,19 +68,28 @@ func (a *aes256) Encrypt(message message.Message) (*Message, error) {
defer a.mux.Unlock()
encryptedData := a.encryptor.Seal(nil, nonce, data, a.sessionID)
finalData := make([]byte, 2+len(nonce)+len(encryptedData))
binary.BigEndian.PutUint16(finalData[:2], uint16(len(nonce)))
payload := &Encrypted{Data: encryptedData, Nonce: nonce, Timestamp: time.Now()}
copy(finalData[2:], nonce)
copy(finalData[2+len(nonce):], encryptedData)
return CreateMessage(PayloadEncryptedType, finalData), nil
return CreateMessage(senderID, receiverID, payload)
}
func (a *aes256) Decrypt(message *Message) (message.Message, error) {
func (a *aes256) Decrypt(message *Encrypted) (message.Message, error) {
_, err := a.encryptor.Open(nil, message.Nonce, message.Data, a.sessionID)
if err != nil {
return nil, err
}
// TODO: figure out how to create a message here
return nil, nil
}
func (a *aes256) Close() error {
a.mux.Lock()
defer a.mux.Unlock()
a.sessionID = nil
a.encryptor = nil
return nil
}
+179 -1
View File
@@ -1,7 +1,185 @@
package encrypt
import "github.com/harshabose/skyline_sonata/serve/pkg/interceptor"
import (
"context"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/binary"
"fmt"
"io"
"sync"
"time"
"github.com/coder/websocket"
"github.com/harshabose/skyline_sonata/serve/pkg/interceptor"
"github.com/harshabose/skyline_sonata/serve/pkg/message"
)
type Interceptor struct {
interceptor.NoOpInterceptor
states map[interceptor.Connection]*state
mux sync.Mutex
ctx context.Context
}
func (encrypt *Interceptor) BindSocketConnection(connection *websocket.Conn) error {
encrypt.mux.Lock()
defer encrypt.mux.Unlock()
// TODO: Exchange keys with the peer using a key exchange protocol like Diffie-Hellman
// TODO: Store different keys for encryption and decryption
go encrypt.loop(connection)
return nil
}
func (encrypt *Interceptor) BindSocketWriter(writer interceptor.Writer) interceptor.Writer {
return interceptor.WriterFunc(func(connection interceptor.Connection, messageType websocket.MessageType, message message.Message) error {
state, exists := encrypt.states[connection]
if !exists {
return
}
})
return interceptor.WriterFunc(func(conn *websocket.Conn, messageType websocket.MessageType, data []byte) error {
encrypt.mux.Lock()
collection, exists := encrypt.states[conn]
encrypt.mux.Unlock()
if !exists || collection.encryptor == nil {
// No encryption configured for this connection yet
// Pass through unencrypted
return writer.Write(conn, messageType, data)
}
// Generate a nonce for this message
nonce := make([]byte, 12) // GCM typically uses a 12-byte nonce
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return fmt.Errorf("failed to generate nonce: %w", err)
}
// Encrypt the data
encryptor := collection.encryptor
sessionID := collection.sessionID
encryptedData := encryptor.Seal(nil, nonce, data, sessionID)
// Format the encrypted message:
// [2-byte nonce length][nonce][encrypted data]
finalData := make([]byte, 2+len(nonce)+len(encryptedData))
binary.BigEndian.PutUint16(finalData[:2], uint16(len(nonce)))
copy(finalData[2:], nonce)
copy(finalData[2+len(nonce):], encryptedData)
// Send the encrypted message
return writer.Write(conn, messageType, finalData)
})
}
func (encrypt *Interceptor) BindSocketReader(reader interceptor.Reader) interceptor.Reader {
return interceptor.ReaderFunc(func(conn *websocket.Conn) (websocket.MessageType, []byte, error) {
// Read the encrypted message
messageType, encryptedData, err := reader.Read(conn)
if err != nil {
return messageType, encryptedData, err
}
encrypt.mux.Lock()
collection, exists := encrypt.states[conn]
encrypt.mux.Unlock()
if !exists || collection.encryptor == nil || len(encryptedData) < 2 {
// No decryption configured or data too short to be encrypted
// Pass through as-is
return messageType, encryptedData, nil
}
// Extract nonce length
nonceLen := binary.BigEndian.Uint16(encryptedData[:2])
// Ensure we have enough data for the nonce and at least some ciphertext
if len(encryptedData) < int(2+nonceLen) {
return messageType, encryptedData, fmt.Errorf("encrypted data too short")
}
// Extract nonce and ciphertext
nonce := encryptedData[2 : 2+nonceLen]
ciphertext := encryptedData[2+nonceLen:]
// Decrypt the data
decryptor := collection.encryptor
sessionID := collection.sessionID
plaintext, err := decryptor.Open(nil, nonce, ciphertext, sessionID)
if err != nil {
return messageType, encryptedData, fmt.Errorf("decryption failed: %w", err)
}
return messageType, plaintext, nil
})
}
func (encrypt *Interceptor) UnBindSocketConnection(connection *websocket.Conn) {
encrypt.mux.Lock()
defer encrypt.mux.Unlock()
collection, exists := encrypt.states[connection]
if !exists {
fmt.Println("connection does not exists")
return
}
collection.cancel()
delete(encrypt.states, connection)
}
func (encrypt *Interceptor) UnBindSocketWriter(_ interceptor.Writer) {}
func (encrypt *Interceptor) UnBindSocketReader(_ interceptor.Reader) {}
func (encrypt *Interceptor) Close() error {
encrypt.mux.Lock()
defer encrypt.mux.Unlock()
encrypt.states = make(map[*websocket.Conn]*collection)
return nil
}
func (encrypt *Interceptor) loop(connection *websocket.Conn) {
encrypt.mux.Lock()
collection, exists := encrypt.states[connection]
if !exists {
fmt.Println("connection does not exists")
return
}
ctx := collection.ctx
encrypt.mux.Unlock()
timer := time.NewTicker(5 * time.Minute)
defer timer.Stop()
for {
select {
case <-timer.C:
encrypt.mux.Lock()
newSessionID := make([]byte, 16)
if _, err := io.ReadFull(rand.Reader, newSessionID); err != nil {
fmt.Println("Failed to generate new session messageID:", err)
continue
}
if collection, exists := encrypt.states[connection]; exists {
collection.sessionID = nil // Keep nil until sending to peer mechanism is set
}
// send the update sessionID to peer
encrypt.mux.Unlock()
case <-ctx.Done():
return
}
}
}
+59 -12
View File
@@ -2,22 +2,69 @@ package encrypt
import (
"encoding/json"
"errors"
"time"
"github.com/harshabose/skyline_sonata/serve/pkg/interceptor"
"github.com/harshabose/skyline_sonata/serve/pkg/message"
)
type PayloadType string
var (
MainType interceptor.MainType = "encrypt"
const (
PayloadEncryptedType PayloadType = "encrypt:encrypted"
EncryptedSubType interceptor.SubType = "encrypted"
)
type Message struct {
Type PayloadType `json:"type"`
Payload json.RawMessage `json:"payload"`
}
func CreateMessage(_type PayloadType, payload json.RawMessage) *Message {
return &Message{
Type: _type,
Payload: payload,
func CreateMessage(senderID, receiverID string, payload interceptor.Payload) (*interceptor.BaseMessage, error) {
data, err := payload.Marshal()
if err != nil {
return nil, err
}
return &interceptor.BaseMessage{
BaseMessage: message.BaseMessage{
Header: message.Header{
SenderID: senderID,
ReceiverID: receiverID,
Protocol: interceptor.IProtocol,
},
Payload: data,
},
Header: interceptor.Header{
MainType: MainType,
SubType: payload.Type(),
},
}, nil
}
type Encrypted struct {
Data []byte `json:"data"`
Nonce []byte `json:"nonce"`
Timestamp time.Time `json:"timestamp"`
}
func (payload *Encrypted) Marshal() ([]byte, error) {
return json.Marshal(payload)
}
func (payload *Encrypted) Unmarshal(data []byte) error {
return json.Unmarshal(data, payload)
}
func (payload *Encrypted) Validate() error {
if payload.Data == nil || payload.Nonce == nil || len(payload.Data) <= 0 || len(payload.Nonce) <= 0 {
return errors.New("not valid")
}
return nil
}
func (payload *Encrypted) Process(header interceptor.Header, i interceptor.Interceptor, connection interceptor.Connection) error {
// TODO implement me
panic("implement me")
}
func (payload *Encrypted) Type() interceptor.SubType {
return EncryptedSubType
}
+9
View File
@@ -1 +1,10 @@
package encrypt
import "context"
type state struct {
id string
encryptor encryptor
cancel context.CancelFunc
ctx context.Context
}
+8 -6
View File
@@ -6,6 +6,8 @@ import (
"sync"
"github.com/coder/websocket"
"github.com/harshabose/skyline_sonata/serve/pkg/message"
)
// Registry maintains a collection of interceptor factories that can be used to
@@ -110,29 +112,29 @@ type Writer interface {
// Write sends a message to the connection
// Takes the connection, message type, and message to write
// Returns any error encountered during writing
Write(conn Connection, messageType websocket.MessageType, message Message) error
Write(conn Connection, messageType websocket.MessageType, message message.Message) error
}
// Reader is an interface for reading messages from a websocket connection
type Reader interface {
// Read reads a message from the connection
// Returns the message type, message data, and any error
Read(conn Connection) (messageType websocket.MessageType, message Message, err error)
Read(conn Connection) (messageType websocket.MessageType, message message.Message, err error)
}
// ReaderFunc is a function type that implements the Reader interface
type ReaderFunc func(conn Connection) (messageType websocket.MessageType, message Message, err error)
type ReaderFunc func(conn Connection) (messageType websocket.MessageType, message message.Message, err error)
// Read implements the Reader interface for ReaderFunc
func (f ReaderFunc) Read(conn Connection) (messageType websocket.MessageType, message Message, err error) {
func (f ReaderFunc) Read(conn Connection) (messageType websocket.MessageType, message message.Message, err error) {
return f(conn)
}
// WriterFunc is a function type that implements the Writer interface
type WriterFunc func(conn Connection, messageType websocket.MessageType, message Message) error
type WriterFunc func(conn Connection, messageType websocket.MessageType, message message.Message) error
// Write implements the Writer interface for WriterFunc
func (f WriterFunc) Write(conn Connection, messageType websocket.MessageType, message Message) error {
func (f WriterFunc) Write(conn Connection, messageType websocket.MessageType, message message.Message) error {
return f(conn, messageType, message)
}
+8 -23
View File
@@ -1,37 +1,22 @@
package interceptor
import "encoding/json"
type Message interface {
Marshal() ([]byte, error)
Unmarshal([]byte) error
}
import (
"github.com/harshabose/skyline_sonata/serve/pkg/message"
)
type (
Protocol string
SubType string
MainType string
)
type Header struct {
SenderID string `json:"source_id"`
ReceiverID string `json:"destination_id"`
Protocol Protocol `json:"protocol"`
MainType MainType `json:"main_type"`
SubType SubType `json:"sub_type"`
MainType MainType `json:"main_type"`
SubType SubType `json:"sub_type"`
}
var IProtocol Protocol = "interceptor"
var IProtocol message.Protocol = "interceptor"
type BaseMessage struct { // This actually needs to be interceptor module and Message interface should be in its own module
type BaseMessage struct {
Header
Payload json.RawMessage `json:"payload"`
}
func (msg *BaseMessage) Marshal() ([]byte, error) {
return json.Marshal(msg)
}
func (msg *BaseMessage) Unmarshal(data []byte) error {
return json.Unmarshal(data, msg)
message.BaseMessage
}
+13 -12
View File
@@ -10,6 +10,7 @@ import (
"github.com/google/uuid"
"github.com/harshabose/skyline_sonata/serve/pkg/interceptor"
"github.com/harshabose/skyline_sonata/serve/pkg/message"
)
type Interceptor struct {
@@ -47,17 +48,17 @@ func (i *Interceptor) BindSocketConnection(connection interceptor.Connection, wr
}
func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) interceptor.Writer {
return interceptor.WriterFunc(func(conn interceptor.Connection, messageType websocket.MessageType, message interceptor.Message) error {
return interceptor.WriterFunc(func(conn interceptor.Connection, messageType websocket.MessageType, message message.Message) error {
i.Mutex.Lock()
defer i.Mutex.Unlock()
msg, ok := message.(*Message)
if !ok || (msg.MainType != "ping" && msg.SubType != "ping") {
msg, ok := message.(*interceptor.BaseMessage)
if !ok || (msg.Protocol != interceptor.IProtocol && msg.MainType != MainType) {
return writer.Write(conn, messageType, message)
}
payload := &Ping{}
if err := payload.Unmarshal(msg.Payload); err != nil {
payload, err := PayloadUnmarshal(msg.SubType, msg.Payload)
if err != nil {
return writer.Write(conn, messageType, message)
}
@@ -72,7 +73,7 @@ func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) intercept
}
func (i *Interceptor) InterceptSocketReader(reader interceptor.Reader) interceptor.Reader {
return interceptor.ReaderFunc(func(conn interceptor.Connection) (messageType websocket.MessageType, message interceptor.Message, err error) {
return interceptor.ReaderFunc(func(conn interceptor.Connection) (messageType websocket.MessageType, message message.Message, err error) {
messageType, message, err = reader.Read(conn)
if err != nil {
return messageType, message, err
@@ -81,14 +82,14 @@ func (i *Interceptor) InterceptSocketReader(reader interceptor.Reader) intercept
i.Mutex.Lock()
defer i.Mutex.Unlock()
msg, ok := message.(*Message)
if !ok {
msg, ok := message.(*interceptor.BaseMessage)
if !ok || (msg.Protocol != interceptor.IProtocol && msg.MainType != MainType) {
return messageType, message, nil
}
payload := &Pong{}
if err := payload.Unmarshal(msg.Payload); err != nil {
return messageType, message, nil
payload, err := PayloadUnmarshal(msg.SubType, msg.Payload)
if err != nil {
return messageType, message, err
}
if _, exists := i.states[conn]; exists {
@@ -161,7 +162,7 @@ func (i *Interceptor) loop(ctx context.Context, interval time.Duration, connecti
}
}
func (payload *Ping) Process(header interceptor.Header, interceptor interceptor.Interceptor, connection interceptor.Connection) error {
func (payload *Ping) Process(_ interceptor.Header, interceptor interceptor.Interceptor, connection interceptor.Connection) error {
if err := payload.Validate(); err != nil {
return err
}
+37 -10
View File
@@ -2,29 +2,48 @@ package ping
import (
"encoding/json"
"errors"
"time"
"github.com/harshabose/skyline_sonata/serve/pkg/interceptor"
)
type Message struct {
interceptor.BaseMessage
var (
MainType interceptor.MainType = "ping-pong"
PingSubType interceptor.SubType = "ping"
PongSubType interceptor.SubType = "pong"
subTypeMap = map[interceptor.SubType]interceptor.Payload{
PingSubType: &Ping{},
PongSubType: &Pong{},
}
)
func PayloadUnmarshal(sub interceptor.SubType, p json.RawMessage) (interceptor.Payload, error) {
if payload, exists := subTypeMap[sub]; exists {
if err := payload.Unmarshal(p); err != nil {
return nil, err
}
return payload, nil
}
return nil, errors.New("processor does not exist for given type")
}
func CreateMessage(senderID, receiverID string, payload interceptor.Payload) (*Message, error) {
func CreateMessage(senderID, receiverID string, payload interceptor.Payload) (*interceptor.BaseMessage, error) {
data, err := payload.Marshal()
if err != nil {
return nil, err
}
return &Message{
BaseMessage: interceptor.BaseMessage{
Header: interceptor.Header{
SenderID: senderID,
ReceiverID: receiverID,
},
Payload: data,
return &interceptor.BaseMessage{
Header: interceptor.Header{
SenderID: senderID,
ReceiverID: receiverID,
MainType: MainType,
SubType: payload.Type(),
},
Payload: data,
}, nil
}
@@ -69,6 +88,10 @@ func (payload *Ping) Validate() error {
return nil
}
func (payload *Ping) Type() interceptor.SubType {
return PingSubType
}
// Pong represents a response to a ping message, confirming connection health.
// It contains the original ping's message ID and timestamp, plus its own timestamp,
// allowing the server to calculate the round-trip time.
@@ -110,3 +133,7 @@ func (payload *Pong) Unmarshal(data []byte) error {
func (payload *Pong) Validate() error {
return nil
}
func (payload *Pong) Type() interceptor.SubType {
return PongSubType
}
+10 -9
View File
@@ -9,6 +9,7 @@ import (
"github.com/coder/websocket"
"github.com/harshabose/skyline_sonata/serve/pkg/interceptor"
"github.com/harshabose/skyline_sonata/serve/pkg/message"
)
type Interceptor struct {
@@ -43,17 +44,17 @@ func (i *Interceptor) BindSocketConnection(connection interceptor.Connection, wr
}
func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) interceptor.Writer {
return interceptor.WriterFunc(func(conn interceptor.Connection, messageType websocket.MessageType, message interceptor.Message) error {
return interceptor.WriterFunc(func(conn interceptor.Connection, messageType websocket.MessageType, message message.Message) error {
i.Mutex.Lock()
defer i.Mutex.Unlock()
msg, ok := message.(*interceptor.BaseMessage)
if !ok || (msg.Header.Protocol != interceptor.IProtocol && msg.Header.MainType != MainType) {
if !ok || (msg.Protocol != interceptor.IProtocol && msg.Header.MainType != MainType) {
return writer.Write(conn, messageType, message)
}
payload := &Pong{}
if err := payload.Unmarshal(msg.Payload); err != nil {
payload, err := PayloadUnmarshal(msg.SubType, msg.Payload)
if err != nil {
return writer.Write(conn, messageType, message)
}
@@ -68,7 +69,7 @@ func (i *Interceptor) InterceptSocketWriter(writer interceptor.Writer) intercept
}
func (i *Interceptor) InterceptSocketReader(reader interceptor.Reader) interceptor.Reader {
return interceptor.ReaderFunc(func(conn interceptor.Connection) (messageType websocket.MessageType, message interceptor.Message, err error) {
return interceptor.ReaderFunc(func(conn interceptor.Connection) (messageType websocket.MessageType, message message.Message, err error) {
messageType, message, err = reader.Read(conn)
if err != nil {
return messageType, message, err
@@ -78,13 +79,13 @@ func (i *Interceptor) InterceptSocketReader(reader interceptor.Reader) intercept
defer i.Mutex.Unlock()
msg, ok := message.(*interceptor.BaseMessage)
if !ok || (msg.Header.Protocol != interceptor.IProtocol && msg.Header.MainType != MainType) {
if !ok || (msg.Protocol != interceptor.IProtocol && msg.Header.MainType != MainType) {
return messageType, message, nil
}
payload := &Ping{}
if err := payload.Unmarshal(msg.Payload); err != nil {
return messageType, message, nil
payload, err := PayloadUnmarshal(msg.SubType, msg.Payload)
if err != nil {
return messageType, message, err
}
if _, exists := i.states[conn]; exists {
+23 -1
View File
@@ -2,12 +2,34 @@ package pong
import (
"encoding/json"
"errors"
"time"
"github.com/harshabose/skyline_sonata/serve/pkg/interceptor"
)
var MainType interceptor.MainType = "pong"
var (
MainType interceptor.MainType = "ping-pong"
PingSubType interceptor.SubType = "ping"
PongSubType interceptor.SubType = "pong"
subTypeMap = map[interceptor.SubType]interceptor.Payload{
PingSubType: &Ping{},
PongSubType: &Pong{},
}
)
func PayloadUnmarshal(sub interceptor.SubType, p json.RawMessage) (interceptor.Payload, error) {
if payload, exists := subTypeMap[sub]; exists {
if err := payload.Unmarshal(p); err != nil {
return nil, err
}
return payload, nil
}
return nil, errors.New("processor does not exist for given type")
}
func CreateMessage(senderID, receiverID string, payload interceptor.Payload) (*interceptor.BaseMessage, error) {
data, err := payload.Marshal()
+7 -7
View File
@@ -8,6 +8,7 @@ import (
"github.com/coder/websocket"
"github.com/harshabose/skyline_sonata/serve/pkg/interceptor"
"github.com/harshabose/skyline_sonata/serve/pkg/message"
)
type Interceptor struct {
@@ -30,7 +31,7 @@ func (i *Interceptor) BindSocketConnection(connection interceptor.Connection, wr
}
func (i *Interceptor) InterceptSocketReader(reader interceptor.Reader) interceptor.Reader {
return interceptor.ReaderFunc(func(connection interceptor.Connection) (websocket.MessageType, interceptor.Message, error) {
return interceptor.ReaderFunc(func(connection interceptor.Connection) (websocket.MessageType, message.Message, error) {
messageType, data, err := reader.Read(connection)
if err != nil {
return messageType, data, err
@@ -44,15 +45,14 @@ func (i *Interceptor) InterceptSocketReader(reader interceptor.Reader) intercept
i.Mutex.Lock()
defer i.Mutex.Unlock()
payload, err := PayloadUnmarshal(msg.SubType, msg.Payload)
if err != nil {
return messageType, data, nil
}
if _, exists := i.states[connection]; exists {
payload, err := PayloadUnmarshal(msg.SubType, msg.Payload)
if err != nil {
fmt.Println("error while processing room message: ", err.Error())
return messageType, data, nil
}
if err := payload.Process(msg.Header, i, connection); err != nil {
fmt.Println("error while processing room message: ", err.Error())
return messageType, data, nil
}
}
+23
View File
@@ -1,6 +1,29 @@
package message
import "encoding/json"
type Protocol string
type Message interface {
Marshal() ([]byte, error)
Unmarshal([]byte) error
}
type Header struct {
SenderID string `json:"source_id"`
ReceiverID string `json:"destination_id"`
Protocol Protocol `json:"protocol"`
}
type BaseMessage struct {
Header
Payload json.RawMessage `json:"payload,omitempty"`
}
func (msg *BaseMessage) Marshal() ([]byte, error) {
return json.Marshal(msg)
}
func (msg *BaseMessage) Unmarshal(data []byte) error {
return json.Unmarshal(data, msg)
}