diff --git a/attributes.go b/attributes.go index ebba38a..ed4a9b6 100644 --- a/attributes.go +++ b/attributes.go @@ -161,3 +161,32 @@ func (m *Message) Get(t AttrType) ([]byte, error) { } return v.Value, nil } + +// AttrLengthError occurs when len(v) > Max. +type AttrLengthError struct { + Type AttrType + Max int + Got int +} + +func (e AttrLengthError) Error() string { + return fmt.Sprintf("Length of %s attribute %d exceeds maximum %d", + e.Type, e.Got, e.Max, + ) +} + +// STUN aligns attributes on 32-bit boundaries, attributes whose content +// is not a multiple of 4 bytes are padded with 1, 2, or 3 bytes of +// padding so that its value contains a multiple of 4 bytes. The +// padding bits are ignored, and may be any value. +// +// https://tools.ietf.org/html/rfc5389#section-15 +const padding = 4 + +func nearestPaddedValueLength(l int) int { + n := padding * (l / padding) + if n < l { + n += padding + } + return n +} diff --git a/attributes_test.go b/attributes_test.go index dcb2054..b04ce76 100644 --- a/attributes_test.go +++ b/attributes_test.go @@ -34,3 +34,28 @@ func TestMessage_GetNoAllocs(t *testing.T) { } }) } + +func TestPadding(t *testing.T) { + tt := []struct { + in, out int + }{ + {4, 4}, // 0 + {2, 4}, // 1 + {5, 8}, // 2 + {8, 8}, // 3 + {11, 12}, // 4 + {1, 4}, // 5 + {3, 4}, // 6 + {6, 8}, // 7 + {7, 8}, // 8 + {0, 0}, // 9 + {40, 40}, // 10 + } + for i, c := range tt { + if got := nearestPaddedValueLength(c.in); got != c.out { + t.Errorf("[%d]: padd(%d) %d (got) != %d (expected)", + i, c.in, got, c.out, + ) + } + } +} diff --git a/fingerprint.go b/fingerprint.go index 27ee52c..cedcf07 100644 --- a/fingerprint.go +++ b/fingerprint.go @@ -68,7 +68,10 @@ func (FingerprintAttr) Check(m *Message) error { return err } if len(b) != fingerprintSize { - return newDecodeErr("message", "fingerprint", "bad length") + return newDecodeErr("message", + "fingerprint", + "bad length", + ) } val := bin.Uint32(b) attrStart := len(m.Raw) - (fingerprintSize + attributeHeaderSize) diff --git a/helpers.go b/helpers.go index e311780..2fc35d3 100644 --- a/helpers.go +++ b/helpers.go @@ -1,22 +1,38 @@ package stun -// Setter sets *Message attribute. -type Setter interface { - AddTo(m *Message) error -} +// Interfaces that are implemented by message attributes, shorthands for them, +// or helpers for message fields as type or transaction id. +type ( + // Setter sets *Message attribute. + Setter interface { + AddTo(m *Message) error + } + // Getter parses attribute from *Message. + Getter interface { + GetFrom(m *Message) error + } + // Checker checks *Message attribute. + Checker interface { + Check(m *Message) error + } +) -// Getter decodes *Message attribute. -type Getter interface { - GetFrom(m *Message) error -} - -// Checker checks *Message attribute. -type Checker interface { - Check(m *Message) error -} - -// Build applies setters to message. -func (m *Message) Build(setters... Setter) error { +// Build resets message and applies setters to it in batch, returning on +// first error. To prevent allocations, pass pointers to values. +// +// Example: +// var ( +// t = BindingRequest +// username = NewUsername("username") +// nonce = NewNonce("nonce") +// realm = NewRealm("example.org") +// ) +// m := new(Message) +// m.Build(t, username, nonce, realm) // 4 allocations +// m.Build(&t, &username, &nonce, &realm) // 0 allocations +// +// See BenchmarkBuildOverhead. +func (m *Message) Build(setters ...Setter) error { m.Reset() m.WriteHeader() for _, s := range setters { @@ -27,7 +43,8 @@ func (m *Message) Build(setters... Setter) error { return nil } -func (m *Message) Check(checkers... Checker) error { +// Checker applies chereks to message in batch, returning on first error. +func (m *Message) Check(checkers ...Checker) error { for _, c := range checkers { if err := c.Check(m); err != nil { return err @@ -36,8 +53,18 @@ func (m *Message) Check(checkers... Checker) error { return nil } +// Parse applies getters to message in batch, returning on first error. +func (m *Message) Parse(getters ...Getter) error { + for _, c := range getters { + if err := c.GetFrom(m); err != nil { + return err + } + } + return nil +} + // Build wraps Message.Build method. -func Build(setters... Setter) (*Message, error) { +func Build(setters ...Setter) (*Message, error) { m := new(Message) return m, m.Build(setters...) } diff --git a/helpers_test.go b/helpers_test.go index d760b89..619f242 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -2,14 +2,48 @@ package stun import "testing" +func BenchmarkBuildOverhead(b *testing.B) { + var ( + t = BindingRequest + username = NewUsername("username") + nonce = NewNonce("nonce") + realm = NewRealm("example.org") + ) + b.Run("Build", func(b *testing.B) { + b.ReportAllocs() + m := new(Message) + for i := 0; i < b.N; i++ { + m.Build(&t, &username, &nonce, &realm, &Fingerprint) + } + }) + b.Run("BuildNonPointer", func(b *testing.B) { + b.ReportAllocs() + m := new(Message) + for i := 0; i < b.N; i++ { + m.Build(t, username, nonce, realm, Fingerprint) + } + }) + b.Run("Raw", func(b *testing.B) { + b.ReportAllocs() + m := new(Message) + for i := 0; i < b.N; i++ { + m.Reset() + m.WriteHeader() + m.SetType(t) + username.AddTo(m) + nonce.AddTo(m) + realm.AddTo(m) + Fingerprint.AddTo(m) + } + }) +} + func TestMessage_Apply(t *testing.T) { var ( integrity = NewShortTermIntegrity("password") - decoded = new(Message) + decoded = new(Message) ) - m, err := Build( - NewType(ClassRequest, MethodBinding), - TransactionID, + m, err := Build(BindingRequest, TransactionID, NewUsername("username"), NewNonce("nonce"), NewRealm("example.org"), @@ -28,7 +62,6 @@ func TestMessage_Apply(t *testing.T) { if !decoded.Equal(m) { t.Error("not equal") } - if err := integrity.Check(decoded); err != nil { t.Fatal(err) } diff --git a/integrity.go b/integrity.go index 2609f7b..2f2e139 100644 --- a/integrity.go +++ b/integrity.go @@ -73,7 +73,7 @@ func (i MessageIntegrity) AddTo(m *Message) error { length := m.Length // Adjusting m.Length to contain MESSAGE-INTEGRITY TLV. m.Length += messageIntegritySize + attributeHeaderSize - m.WriteLength() // writing length to m.Raw + m.WriteLength() // writing length to m.Raw v := newHMAC(i, m.Raw) // calculating HMAC for adjusted m.Raw m.Length = length // changing m.Length back m.Add(AttrMessageIntegrity, v) diff --git a/message.go b/message.go index 4c77853..fdcd421 100644 --- a/message.go +++ b/message.go @@ -3,7 +3,6 @@ package stun import ( "crypto/rand" "encoding/base64" - "encoding/binary" "errors" "fmt" "io" @@ -202,10 +201,6 @@ func (m *Message) WriteHeader() { copy(m.Raw[8:messageHeaderSize], m.TransactionID[:]) // transaction ID } -func (m *Message) writeMagicCookie() { - -} - func (m *Message) WriteTransactionID() { copy(m.Raw[8:messageHeaderSize], m.TransactionID[:]) // transaction ID } @@ -272,10 +267,10 @@ func (m *Message) Decode() error { return ErrUnexpectedHeaderEOF } var ( - t = binary.BigEndian.Uint16(buf[0:2]) // first 2 bytes - size = int(binary.BigEndian.Uint16(buf[2:4])) // second 2 bytes - cookie = binary.BigEndian.Uint32(buf[4:8]) - fullSize = messageHeaderSize + size + t = bin.Uint16(buf[0:2]) // first 2 bytes + size = int(bin.Uint16(buf[2:4])) // second 2 bytes + cookie = bin.Uint32(buf[4:8]) // last 4 bytes + fullSize = messageHeaderSize + size // len(m.Raw) ) if cookie != magicCookie { msg := fmt.Sprintf( @@ -322,8 +317,8 @@ func (m *Message) Decode() error { offset += attributeHeaderSize if len(b) < aBuffL { // checking size msg := fmt.Sprintf( - "buffer length %d is less than %d (expected value size)", - len(b), aBuffL, + "buffer length %d is less than %d (expected value size for %s)", + len(b), aBuffL, a.Type, ) return newAttrDecodeErr("value", msg) } @@ -344,10 +339,6 @@ func (m *Message) Write(tBuf []byte) (int, error) { return len(tBuf), m.Decode() } -// MaxPacketSize is maximum size of UDP packet that is processable in -// this package for STUN message. -const MaxPacketSize = 2048 - // MessageClass is 8-bit representation of 2-bit class of STUN Message Class. type MessageClass byte @@ -359,6 +350,16 @@ const ( ClassErrorResponse MessageClass = 0x03 // 0b11 ) +// Common STUN message types. +var ( + // Binding request message type. + BindingRequest = NewType(MethodBinding, ClassRequest) + // Binding success response message type + BindingSuccess = NewType(MethodBinding, ClassSuccessResponse) + // Binding error response message type. + BindingError = NewType(MethodBinding, ClassErrorResponse) +) + func (c MessageClass) String() string { switch c { case ClassRequest: @@ -411,16 +412,18 @@ func (m Method) String() string { // MessageType is STUN Message Type Field. type MessageType struct { - Class MessageClass - Method Method + Method Method // e.g. binding + Class MessageClass // e.g. request } +// AddTo sets m type to t. func (t MessageType) AddTo(m *Message) error { m.SetType(t) return nil } -func NewType(class MessageClass, method Method) MessageType { +// NewType returns new message type with provided method and class. +func NewType(method Method, class MessageClass) MessageType { return MessageType{ Method: method, Class: class, @@ -465,7 +468,7 @@ func (t MessageType) Value() uint16 { // Shifting to add "holes" for C0 (at 4 bit) and C1 (8 bit). m = a + (b << methodBShift) + (d << methodDShift) - // C0 is zero bit of C, C1 is fist bit. + // C0 is zero bit of C, C1 is first bit. // C0 = C * 0b01, C1 = (C * 0b10) >> 1 // Ct = C0 << 4 + C1 << 8. // Optimizations: "((C * 0b10) >> 1) << 8" as "(C * 0b10) << 7" diff --git a/nonce.go b/nonce.go deleted file mode 100644 index a5d4899..0000000 --- a/nonce.go +++ /dev/null @@ -1,41 +0,0 @@ -package stun - -import "errors" - -// Nonce represents NONCE attribute. -// -// https://tools.ietf.org/html/rfc5389#section-15.8 -type Nonce []byte - -// NewNonce returns new Nonce from string. -func NewNonce(nonce string) Nonce { - return Nonce(nonce) -} - -func (n Nonce) String() string { - return string(n) -} - -const maxNonceB = 763 - -// ErrNonceTooBig means that NONCE value is bigger that 763 bytes. -var ErrNonceTooBig = errors.New("NONCE value bigger than 763 bytes") - -// AddTo adds NONCE to message. -func (n Nonce) AddTo(m *Message) error { - if len(n) > maxNonceB { - return ErrNonceTooBig - } - m.Add(AttrNonce, n) - return nil -} - -// GetFrom gets NONCE from message. -func (n *Nonce) GetFrom(m *Message) error { - v, err := m.Get(AttrNonce) - if err != nil { - return err - } - *n = v - return nil -} diff --git a/nonce_test.go b/nonce_test.go deleted file mode 100644 index a153406..0000000 --- a/nonce_test.go +++ /dev/null @@ -1,62 +0,0 @@ -package stun - -import ( - "strings" - "testing" -) - -func TestNonce_GetFrom(t *testing.T) { - m := New() - v := "example.org" - m.Add(AttrNonce, []byte(v)) - m.WriteHeader() - - m2 := &Message{ - Raw: make([]byte, 0, 256), - } - var nonce Nonce - if _, err := m2.ReadFrom(m.reader()); err != nil { - t.Error(err) - } - if err := nonce.GetFrom(m); err != nil { - t.Fatal(err) - } - if nonce.String() != v { - t.Errorf("Expected %q, got %q.", v, nonce) - } - - nAttr, ok := m.Attributes.Get(AttrNonce) - if !ok { - t.Error("nonce attribute should be found") - } - s := nAttr.String() - if !strings.HasPrefix(s, "NONCE:") { - t.Error("bad string representation", s) - } -} - -func TestNonce_AddTo_Invalid(t *testing.T) { - m := New() - n := make(Nonce, 1024) - if err := n.AddTo(m); err != ErrNonceTooBig { - t.Errorf("AddTo should return %q, got: %v", ErrNonceTooBig, err) - } - if err := n.GetFrom(m); err != ErrAttributeNotFound { - t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err) - } -} - -func TestNonce_AddTo(t *testing.T) { - m := New() - n := Nonce("example.org") - if err := n.AddTo(m); err != nil { - t.Error(err) - } - v, err := m.Get(AttrNonce) - if err != nil { - t.Error(err) - } - if string(v) != "example.org" { - t.Errorf("bad nonce %q", v) - } -} diff --git a/padding.go b/padding.go deleted file mode 100644 index fa73bda..0000000 --- a/padding.go +++ /dev/null @@ -1,17 +0,0 @@ -package stun - -// STUN aligns attributes on 32-bit boundaries, attributes whose content -// is not a multiple of 4 bytes are padded with 1, 2, or 3 bytes of -// padding so that its value contains a multiple of 4 bytes. The -// padding bits are ignored, and may be any value. -// -// https://tools.ietf.org/html/rfc5389#section-15 -const padding = 4 - -func nearestPaddedValueLength(l int) int { - n := padding * (l / padding) - if n < l { - n += padding - } - return n -} diff --git a/padding_test.go b/padding_test.go deleted file mode 100644 index e161bca..0000000 --- a/padding_test.go +++ /dev/null @@ -1,28 +0,0 @@ -package stun - -import "testing" - -func TestPadding(t *testing.T) { - tt := []struct { - in, out int - }{ - {4, 4}, // 0 - {2, 4}, // 1 - {5, 8}, // 2 - {8, 8}, // 3 - {11, 12}, // 4 - {1, 4}, // 5 - {3, 4}, // 6 - {6, 8}, // 7 - {7, 8}, // 8 - {0, 0}, // 9 - {40, 40}, // 10 - } - for i, c := range tt { - if got := nearestPaddedValueLength(c.in); got != c.out { - t.Errorf("[%d]: padd(%d) %d (got) != %d (expected)", - i, c.in, got, c.out, - ) - } - } -} diff --git a/realm.go b/realm.go deleted file mode 100644 index ed06f3a..0000000 --- a/realm.go +++ /dev/null @@ -1,43 +0,0 @@ -package stun - -import "errors" - -// NewRealm returns Realm with provided value. -// Must be SASL-prepared. -func NewRealm(realm string) Realm { - // TODO: use sasl - return Realm(realm) -} - -// Realm represents REALM attribute. -// -// https://tools.ietf.org/html/rfc5389#section-15.7 -type Realm []byte - -func (n Realm) String() string { - return string(n) -} - -const maxRealmB = 763 - -// ErrRealmTooBig means that REALM value is bigger that 763 bytes. -var ErrRealmTooBig = errors.New("REALM value bigger than 763 bytes") - -// AddTo adds NONCE to message. -func (n Realm) AddTo(m *Message) error { - if len(n) > maxRealmB { - return ErrRealmTooBig - } - m.Add(AttrRealm, n) - return nil -} - -// GetFrom gets REALM from message. -func (n *Realm) GetFrom(m *Message) error { - v, err := m.Get(AttrRealm) - if err != nil { - return err - } - *n = v - return nil -} diff --git a/realm_test.go b/realm_test.go deleted file mode 100644 index c8232a6..0000000 --- a/realm_test.go +++ /dev/null @@ -1,50 +0,0 @@ -package stun - -import ( - "strings" - "testing" -) - -func TestRealm_GetFrom(t *testing.T) { - m := New() - v := "realm" - m.Add(AttrRealm, []byte(v)) - m.WriteHeader() - - m2 := &Message{ - Raw: make([]byte, 0, 256), - } - r := new(Realm) - if err := r.GetFrom(m2); err != ErrAttributeNotFound { - t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err) - } - if _, err := m2.ReadFrom(m.reader()); err != nil { - t.Error(err) - } - if err := r.GetFrom(m); err != nil { - t.Fatal(err) - } - if r.String() != v { - t.Errorf("Expected %q, got %q.", v, r) - } - - rAttr, ok := m.Attributes.Get(AttrRealm) - if !ok { - t.Error("realm attribute should be found") - } - s := rAttr.String() - if !strings.HasPrefix(s, "REALM:") { - t.Error("bad string representation", s) - } -} - -func TestRealm_AddTo_Invalid(t *testing.T) { - m := New() - r := make(Realm, 1024) - if err := r.AddTo(m); err != ErrRealmTooBig { - t.Errorf("AddTo should return %q, got: %v", ErrRealmTooBig, err) - } - if err := r.GetFrom(m); err != ErrAttributeNotFound { - t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err) - } -} diff --git a/software.go b/software.go deleted file mode 100644 index f250fd3..0000000 --- a/software.go +++ /dev/null @@ -1,44 +0,0 @@ -package stun - -import "errors" - -const softwareRawMaxB = 763 - -// ErrSoftwareTooBig means that it is not less than 128 characters -// (which can be as long as 763 bytes). -var ErrSoftwareTooBig = errors.New( - "SOFTWARE attribute bigger than 763 bytes or 128 characters", -) - -// Software is SOFTWARE attribute. -// -// https://tools.ietf.org/html/rfc5389#section-15.10 -type Software []byte - -func (s Software) String() string { - return string(s) -} - -// NewSoftware returns *Software from string. -func NewSoftware(software string) Software { - return Software(software) -} - -// AddTo adds Software attribute to m. -func (s Software) AddTo(m *Message) error { - if len(s) > softwareRawMaxB { - return ErrSoftwareTooBig - } - m.Add(AttrSoftware, m.Raw) - return nil -} - -// GetFrom decodes Software from m. -func (s *Software) GetFrom(m *Message) error { - v, err := m.Get(AttrSoftware) - if err != nil { - return err - } - *s = v - return nil -} diff --git a/software_test.go b/software_test.go deleted file mode 100644 index 2e38ce4..0000000 --- a/software_test.go +++ /dev/null @@ -1,56 +0,0 @@ -package stun - -import ( - "strings" - "testing" -) - -func TestSoftware_GetFrom(t *testing.T) { - m := New() - v := "Client v0.0.1" - m.Add(AttrSoftware, []byte(v)) - m.WriteHeader() - - m2 := &Message{ - Raw: make([]byte, 0, 256), - } - software := new(Software) - if _, err := m2.ReadFrom(m.reader()); err != nil { - t.Error(err) - } - if err := software.GetFrom(m); err != nil { - t.Fatal(err) - } - if software.String() != v { - t.Errorf("Expected %q, got %q.", v, software) - } - - sAttr, ok := m.Attributes.Get(AttrSoftware) - if !ok { - t.Error("sowfware attribute should be found") - } - s := sAttr.String() - if !strings.HasPrefix(s, "SOFTWARE:") { - t.Error("bad string representation", s) - } -} - -func TestSoftware_AddTo_Invalid(t *testing.T) { - m := New() - s := make(Software, 1024) - if err := s.AddTo(m); err != ErrSoftwareTooBig { - t.Errorf("AddTo should return %q, got: %v", ErrSoftwareTooBig, err) - } - if err := s.GetFrom(m); err != ErrAttributeNotFound { - t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err) - } -} - -func TestSoftware_AddTo_Regression(t *testing.T) { - // s.AddTo checked len(m.Raw) instead of len(s.Raw). - m := &Message{Raw: make([]byte, 2048)} - s := make(Software, 100) - if err := s.AddTo(m); err != nil { - t.Errorf("AddTo should return , got: %v", err) - } -} diff --git a/stun.go b/stun.go index 18b7c5d..6c72273 100644 --- a/stun.go +++ b/stun.go @@ -1,4 +1,7 @@ // Package stun implements Session Traversal Utilities for NAT (STUN) RFC 5389. +// +// The stun package is intended to use by package that implements extension +// to STUN (e.g. TURN) or client/server applications. package stun import "encoding/binary" @@ -6,5 +9,17 @@ import "encoding/binary" // bin is shorthand to binary.BigEndian. var bin = binary.BigEndian -// DefaultPort is IANA assigned Port for "stun" protocol. -const DefaultPort = 3478 +// IANA assigned ports for "stun" protocol/ +const ( + DefaultPort = 3478 + DefaultTLSPort = 5349 +) + +type transactionIDSetter bool + +func (transactionIDSetter) AddTo(m *Message) error { + return m.NewTransactionID() +} + +// TransactionID is Setter for m.TransactionID. +var TransactionID Setter = transactionIDSetter(true) diff --git a/textattrs.go b/textattrs.go new file mode 100644 index 0000000..5f0e228 --- /dev/null +++ b/textattrs.go @@ -0,0 +1,134 @@ +package stun + +// NewUsername returns Username with provided value. +func NewUsername(username string) Username { + return Username(username) +} + +// Username represents USERNAME attribute. +// +// https://tools.ietf.org/html/rfc5389#section-15.3 +type Username []byte + +func (u Username) String() string { + return string(u) +} + +const maxUsernameB = 513 + +// AddTo adds USERNAME attribute to message. +func (u Username) AddTo(m *Message) error { + return TextAttribute(u).AddToAs(m, AttrUsername, maxUsernameB) +} + +// GetFrom gets USERNAME from message. +func (u *Username) GetFrom(m *Message) error { + return (*TextAttribute)(u).GetFromAs(m, AttrUsername) +} + +// NewRealm returns Realm with provided value. +// Must be SASL-prepared. +func NewRealm(realm string) Realm { + // TODO: use sasl + return Realm(realm) +} + +// Realm represents REALM attribute. +// +// https://tools.ietf.org/html/rfc5389#section-15.7 +type Realm []byte + +func (n Realm) String() string { + return string(n) +} + +const maxRealmB = 763 + +// AddTo adds NONCE to message. +func (n Realm) AddTo(m *Message) error { + return TextAttribute(n).AddToAs(m, AttrRealm, maxRealmB) +} + +// GetFrom gets REALM from message. +func (n *Realm) GetFrom(m *Message) error { + return (*TextAttribute)(n).GetFromAs(m, AttrRealm) +} + +const softwareRawMaxB = 763 + +// Software is SOFTWARE attribute. +// +// https://tools.ietf.org/html/rfc5389#section-15.10 +type Software []byte + +func (s Software) String() string { + return string(s) +} + +// NewSoftware returns *Software from string. +func NewSoftware(software string) Software { + return Software(software) +} + +// AddTo adds Software attribute to m. +func (s Software) AddTo(m *Message) error { + return TextAttribute(s).AddToAs(m, AttrSoftware, softwareRawMaxB) +} + +// GetFrom decodes Software from m. +func (s *Software) GetFrom(m *Message) error { + return (*TextAttribute)(s).GetFromAs(m, AttrSoftware) +} + +// Nonce represents NONCE attribute. +// +// https://tools.ietf.org/html/rfc5389#section-15.8 +type Nonce []byte + +// NewNonce returns new Nonce from string. +func NewNonce(nonce string) Nonce { + return Nonce(nonce) +} + +func (n Nonce) String() string { + return string(n) +} + +const maxNonceB = 763 + +// AddTo adds NONCE to message. +func (n Nonce) AddTo(m *Message) error { + return TextAttribute(n).AddToAs(m, AttrNonce, maxNonceB) +} + +// GetFrom gets NONCE from message. +func (n *Nonce) GetFrom(m *Message) error { + return (*TextAttribute)(n).GetFromAs(m, AttrNonce) +} + +// TextAttribute is helper for adding and getting text attributes. +type TextAttribute []byte + +// AddToAs adds attribute with type t to m, checking maximum length. If maxLen +// is less than 0, no check is performed. +func (v TextAttribute) AddToAs(m *Message, t AttrType, maxLen int) error { + if maxLen > 0 && len(v) > maxLen { + return &AttrLengthError{ + Max: maxLen, + Got: len(v), + Type: t, + } + } + m.Add(t, v) + return nil +} + +// GetFromAs gets t attribute from m and appends its value to reseted v. +func (v *TextAttribute) GetFromAs(m *Message, t AttrType) error { + a, err := m.Get(t) + if err != nil { + return err + } + *v = append((*v)[:0], a...) + return nil +} diff --git a/textattrs_test.go b/textattrs_test.go new file mode 100644 index 0000000..133a1f2 --- /dev/null +++ b/textattrs_test.go @@ -0,0 +1,249 @@ +package stun + +import ( + "strings" + "testing" +) + +func TestSoftware_GetFrom(t *testing.T) { + m := New() + v := "Client v0.0.1" + m.Add(AttrSoftware, []byte(v)) + m.WriteHeader() + + m2 := &Message{ + Raw: make([]byte, 0, 256), + } + software := new(Software) + if _, err := m2.ReadFrom(m.reader()); err != nil { + t.Error(err) + } + if err := software.GetFrom(m); err != nil { + t.Fatal(err) + } + if software.String() != v { + t.Errorf("Expected %q, got %q.", v, software) + } + + sAttr, ok := m.Attributes.Get(AttrSoftware) + if !ok { + t.Error("sowfware attribute should be found") + } + s := sAttr.String() + if !strings.HasPrefix(s, "SOFTWARE:") { + t.Error("bad string representation", s) + } +} + +func TestSoftware_AddTo_Invalid(t *testing.T) { + m := New() + s := make(Software, 1024) + if err, ok := s.AddTo(m).(*AttrLengthError); !ok { + t.Errorf("AddTo should return *AttrLengthError, got: %v", err) + } + if err := s.GetFrom(m); err != ErrAttributeNotFound { + t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err) + } +} + +func TestSoftware_AddTo_Regression(t *testing.T) { + // s.AddTo checked len(m.Raw) instead of len(s.Raw). + m := &Message{Raw: make([]byte, 2048)} + s := make(Software, 100) + if err := s.AddTo(m); err != nil { + t.Errorf("AddTo should return , got: %v", err) + } +} + +func BenchmarkUsername_AddTo(b *testing.B) { + b.ReportAllocs() + m := new(Message) + u := Username("test") + for i := 0; i < b.N; i++ { + if err := u.AddTo(m); err != nil { + b.Fatal(err) + } + m.Reset() + } +} + +func BenchmarkUsername_GetFrom(b *testing.B) { + b.ReportAllocs() + m := new(Message) + Username("test").AddTo(m) + for i := 0; i < b.N; i++ { + var u Username + if err := u.GetFrom(m); err != nil { + b.Fatal(err) + } + } +} + +func TestUsername(t *testing.T) { + username := "username" + u := NewUsername(username) + m := new(Message) + m.WriteHeader() + t.Run("Bad length", func(t *testing.T) { + badU := make(Username, 600) + if err, ok := badU.AddTo(m).(*AttrLengthError); !ok { + t.Errorf("expected length error, got %v", err) + } + }) + t.Run("AddTo", func(t *testing.T) { + if err := u.AddTo(m); err != nil { + t.Error("errored:", err) + } + t.Run("GetFrom", func(t *testing.T) { + got := new(Username) + if err := got.GetFrom(m); err != nil { + t.Error("errored:", err) + } + if got.String() != username { + t.Errorf("expedted: %s, got: %s", username, got) + } + t.Run("Not found", func(t *testing.T) { + m := new(Message) + u := new(Username) + if err := u.GetFrom(m); err != ErrAttributeNotFound { + t.Error("Should error") + } + }) + }) + }) + t.Run("No allocations", func(t *testing.T) { + m := new(Message) + m.WriteHeader() + u := NewUsername("username") + if allocs := testing.AllocsPerRun(10, func() { + if err := u.AddTo(m); err != nil { + t.Error(err) + } + m.Reset() + }); allocs > 0 { + t.Errorf("got %f allocations, zero expected", allocs) + } + }) +} + +func TestRealm_GetFrom(t *testing.T) { + m := New() + v := "realm" + m.Add(AttrRealm, []byte(v)) + m.WriteHeader() + + m2 := &Message{ + Raw: make([]byte, 0, 256), + } + r := new(Realm) + if err := r.GetFrom(m2); err != ErrAttributeNotFound { + t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err) + } + if _, err := m2.ReadFrom(m.reader()); err != nil { + t.Error(err) + } + if err := r.GetFrom(m); err != nil { + t.Fatal(err) + } + if r.String() != v { + t.Errorf("Expected %q, got %q.", v, r) + } + + rAttr, ok := m.Attributes.Get(AttrRealm) + if !ok { + t.Error("realm attribute should be found") + } + s := rAttr.String() + if !strings.HasPrefix(s, "REALM:") { + t.Error("bad string representation", s) + } +} + +func TestRealm_AddTo_Invalid(t *testing.T) { + m := New() + r := make(Realm, 1024) + if err, ok := r.AddTo(m).(*AttrLengthError); !ok || err.Type != AttrRealm { + t.Errorf("AddTo should return *AttrLengthError, got: %v", err) + } + if err := r.GetFrom(m); err != ErrAttributeNotFound { + t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err) + } +} + +func TestNonce_GetFrom(t *testing.T) { + m := New() + v := "example.org" + m.Add(AttrNonce, []byte(v)) + m.WriteHeader() + + m2 := &Message{ + Raw: make([]byte, 0, 256), + } + var nonce Nonce + if _, err := m2.ReadFrom(m.reader()); err != nil { + t.Error(err) + } + if err := nonce.GetFrom(m); err != nil { + t.Fatal(err) + } + if nonce.String() != v { + t.Errorf("Expected %q, got %q.", v, nonce) + } + + nAttr, ok := m.Attributes.Get(AttrNonce) + if !ok { + t.Error("nonce attribute should be found") + } + s := nAttr.String() + if !strings.HasPrefix(s, "NONCE:") { + t.Error("bad string representation", s) + } +} + +func TestNonce_AddTo_Invalid(t *testing.T) { + m := New() + n := make(Nonce, 1024) + if err, ok := n.AddTo(m).(*AttrLengthError); !ok || err.Type != AttrNonce { + t.Errorf("AddTo should return *AttrLengthError, got: %v", err) + } + if err := n.GetFrom(m); err != ErrAttributeNotFound { + t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err) + } +} + +func TestNonce_AddTo(t *testing.T) { + m := New() + n := Nonce("example.org") + if err := n.AddTo(m); err != nil { + t.Error(err) + } + v, err := m.Get(AttrNonce) + if err != nil { + t.Error(err) + } + if string(v) != "example.org" { + t.Errorf("bad nonce %q", v) + } +} + +func BenchmarkNonce_AddTo(b *testing.B) { + b.ReportAllocs() + m := New() + n := NewNonce("nonce") + for i := 0; i < b.N; i++ { + if err := n.AddTo(m); err != nil { + b.Fatal(err) + } + m.Reset() + } +} + +func BenchmarkNonce_GetFrom(b *testing.B) { + b.ReportAllocs() + m := New() + n := NewNonce("nonce") + n.AddTo(m) + for i := 0; i < b.N; i++ { + n.GetFrom(m) + } +} diff --git a/transaction.go b/transaction.go deleted file mode 100644 index c9a9e64..0000000 --- a/transaction.go +++ /dev/null @@ -1,10 +0,0 @@ -package stun - -type transactionIDSetter bool - -func (transactionIDSetter) AddTo(m *Message) error { - return m.NewTransactionID() -} - -// TransactionID is Setter for m.TransactionID. -var TransactionID Setter = transactionIDSetter(true) diff --git a/username.go b/username.go deleted file mode 100644 index 7e04757..0000000 --- a/username.go +++ /dev/null @@ -1,41 +0,0 @@ -package stun - -import "errors" - -// NewUsername returns Username with provided value. -func NewUsername(username string) Username { - return Username(username) -} - -// Username represents USERNAME attribute. -// -// https://tools.ietf.org/html/rfc5389#section-15.3 -type Username []byte - -func (u Username) String() string { - return string(u) -} - -const maxUsernameB = 513 - -// ErrUsernameTooBig means that USERNAME value is bigger that 513 bytes. -var ErrUsernameTooBig = errors.New("USERNAME value bigger than 513 bytes") - -// AddTo adds USERNAME attribute to message. -func (u Username) AddTo(m *Message) error { - if len(u) > maxUsernameB { - return ErrUsernameTooBig - } - m.Add(AttrUsername, u) - return nil -} - -// GetFrom gets USERNAME from message. -func (u *Username) GetFrom(m *Message) error { - v, err := m.Get(AttrUsername) - if err != nil { - return err - } - *u = v - return nil -} diff --git a/username_test.go b/username_test.go deleted file mode 100644 index eb51461..0000000 --- a/username_test.go +++ /dev/null @@ -1,76 +0,0 @@ -package stun - -import ( - "testing" -) - -func BenchmarkUsername_AddTo(b *testing.B) { - b.ReportAllocs() - m := new(Message) - u := Username("test") - for i := 0; i < b.N; i++ { - if err := u.AddTo(m); err != nil { - b.Fatal(err) - } - m.Reset() - } -} - -func BenchmarkUsername_GetFrom(b *testing.B) { - b.ReportAllocs() - m := new(Message) - Username("test").AddTo(m) - for i := 0; i < b.N; i++ { - var u Username - if err := u.GetFrom(m); err != nil { - b.Fatal(err) - } - } -} - -func TestUsername(t *testing.T) { - username := "username" - u := NewUsername(username) - m := new(Message) - m.WriteHeader() - t.Run("Bad length", func(t *testing.T) { - badU := make(Username, 600) - if err := badU.AddTo(m); err != ErrUsernameTooBig { - t.Errorf("expected %s, got %v", ErrUsernameTooBig, err) - } - }) - t.Run("AddTo", func(t *testing.T) { - if err := u.AddTo(m); err != nil { - t.Error("errored:", err) - } - t.Run("GetFrom", func(t *testing.T) { - got := new(Username) - if err := got.GetFrom(m); err != nil { - t.Error("errored:", err) - } - if got.String() != username { - t.Errorf("expedted: %s, got: %s", username, got) - } - t.Run("Not found", func(t *testing.T) { - m := new(Message) - u := new(Username) - if err := u.GetFrom(m); err != ErrAttributeNotFound { - t.Error("Should error") - } - }) - }) - }) - t.Run("No allocations", func(t *testing.T) { - m := new(Message) - m.WriteHeader() - u := NewUsername("username") - if allocs := testing.AllocsPerRun(10, func() { - if err := u.AddTo(m); err != nil { - t.Error(err) - } - m.Reset() - }); allocs > 0 { - t.Errorf("got %f allocations, zero expected", allocs) - } - }) -}