all: refactor text attributes and cleanup

This commit is contained in:
Aleksandr Razumov 2017-02-22 19:22:02 +03:00
parent a21c83b6ca
commit 542983501d
21 changed files with 564 additions and 514 deletions

View File

@ -161,3 +161,32 @@ func (m *Message) Get(t AttrType) ([]byte, error) {
} }
return v.Value, nil return v.Value, nil
} }
// AttrLengthError occurs when len(v) > Max.
type AttrLengthError struct {
Type AttrType
Max int
Got int
}
func (e AttrLengthError) Error() string {
return fmt.Sprintf("Length of %s attribute %d exceeds maximum %d",
e.Type, e.Got, e.Max,
)
}
// STUN aligns attributes on 32-bit boundaries, attributes whose content
// is not a multiple of 4 bytes are padded with 1, 2, or 3 bytes of
// padding so that its value contains a multiple of 4 bytes. The
// padding bits are ignored, and may be any value.
//
// https://tools.ietf.org/html/rfc5389#section-15
const padding = 4
func nearestPaddedValueLength(l int) int {
n := padding * (l / padding)
if n < l {
n += padding
}
return n
}

View File

@ -34,3 +34,28 @@ func TestMessage_GetNoAllocs(t *testing.T) {
} }
}) })
} }
func TestPadding(t *testing.T) {
tt := []struct {
in, out int
}{
{4, 4}, // 0
{2, 4}, // 1
{5, 8}, // 2
{8, 8}, // 3
{11, 12}, // 4
{1, 4}, // 5
{3, 4}, // 6
{6, 8}, // 7
{7, 8}, // 8
{0, 0}, // 9
{40, 40}, // 10
}
for i, c := range tt {
if got := nearestPaddedValueLength(c.in); got != c.out {
t.Errorf("[%d]: padd(%d) %d (got) != %d (expected)",
i, c.in, got, c.out,
)
}
}
}

View File

@ -68,7 +68,10 @@ func (FingerprintAttr) Check(m *Message) error {
return err return err
} }
if len(b) != fingerprintSize { if len(b) != fingerprintSize {
return newDecodeErr("message", "fingerprint", "bad length") return newDecodeErr("message",
"fingerprint",
"bad length",
)
} }
val := bin.Uint32(b) val := bin.Uint32(b)
attrStart := len(m.Raw) - (fingerprintSize + attributeHeaderSize) attrStart := len(m.Raw) - (fingerprintSize + attributeHeaderSize)

View File

@ -1,22 +1,38 @@
package stun package stun
// Setter sets *Message attribute. // Interfaces that are implemented by message attributes, shorthands for them,
type Setter interface { // or helpers for message fields as type or transaction id.
AddTo(m *Message) error type (
} // Setter sets *Message attribute.
Setter interface {
AddTo(m *Message) error
}
// Getter parses attribute from *Message.
Getter interface {
GetFrom(m *Message) error
}
// Checker checks *Message attribute.
Checker interface {
Check(m *Message) error
}
)
// Getter decodes *Message attribute. // Build resets message and applies setters to it in batch, returning on
type Getter interface { // first error. To prevent allocations, pass pointers to values.
GetFrom(m *Message) error //
} // Example:
// var (
// Checker checks *Message attribute. // t = BindingRequest
type Checker interface { // username = NewUsername("username")
Check(m *Message) error // nonce = NewNonce("nonce")
} // realm = NewRealm("example.org")
// )
// Build applies setters to message. // m := new(Message)
func (m *Message) Build(setters... Setter) error { // m.Build(t, username, nonce, realm) // 4 allocations
// m.Build(&t, &username, &nonce, &realm) // 0 allocations
//
// See BenchmarkBuildOverhead.
func (m *Message) Build(setters ...Setter) error {
m.Reset() m.Reset()
m.WriteHeader() m.WriteHeader()
for _, s := range setters { for _, s := range setters {
@ -27,7 +43,8 @@ func (m *Message) Build(setters... Setter) error {
return nil return nil
} }
func (m *Message) Check(checkers... Checker) error { // Checker applies chereks to message in batch, returning on first error.
func (m *Message) Check(checkers ...Checker) error {
for _, c := range checkers { for _, c := range checkers {
if err := c.Check(m); err != nil { if err := c.Check(m); err != nil {
return err return err
@ -36,8 +53,18 @@ func (m *Message) Check(checkers... Checker) error {
return nil return nil
} }
// Parse applies getters to message in batch, returning on first error.
func (m *Message) Parse(getters ...Getter) error {
for _, c := range getters {
if err := c.GetFrom(m); err != nil {
return err
}
}
return nil
}
// Build wraps Message.Build method. // Build wraps Message.Build method.
func Build(setters... Setter) (*Message, error) { func Build(setters ...Setter) (*Message, error) {
m := new(Message) m := new(Message)
return m, m.Build(setters...) return m, m.Build(setters...)
} }

View File

@ -2,14 +2,48 @@ package stun
import "testing" import "testing"
func BenchmarkBuildOverhead(b *testing.B) {
var (
t = BindingRequest
username = NewUsername("username")
nonce = NewNonce("nonce")
realm = NewRealm("example.org")
)
b.Run("Build", func(b *testing.B) {
b.ReportAllocs()
m := new(Message)
for i := 0; i < b.N; i++ {
m.Build(&t, &username, &nonce, &realm, &Fingerprint)
}
})
b.Run("BuildNonPointer", func(b *testing.B) {
b.ReportAllocs()
m := new(Message)
for i := 0; i < b.N; i++ {
m.Build(t, username, nonce, realm, Fingerprint)
}
})
b.Run("Raw", func(b *testing.B) {
b.ReportAllocs()
m := new(Message)
for i := 0; i < b.N; i++ {
m.Reset()
m.WriteHeader()
m.SetType(t)
username.AddTo(m)
nonce.AddTo(m)
realm.AddTo(m)
Fingerprint.AddTo(m)
}
})
}
func TestMessage_Apply(t *testing.T) { func TestMessage_Apply(t *testing.T) {
var ( var (
integrity = NewShortTermIntegrity("password") integrity = NewShortTermIntegrity("password")
decoded = new(Message) decoded = new(Message)
) )
m, err := Build( m, err := Build(BindingRequest, TransactionID,
NewType(ClassRequest, MethodBinding),
TransactionID,
NewUsername("username"), NewUsername("username"),
NewNonce("nonce"), NewNonce("nonce"),
NewRealm("example.org"), NewRealm("example.org"),
@ -28,7 +62,6 @@ func TestMessage_Apply(t *testing.T) {
if !decoded.Equal(m) { if !decoded.Equal(m) {
t.Error("not equal") t.Error("not equal")
} }
if err := integrity.Check(decoded); err != nil { if err := integrity.Check(decoded); err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -73,7 +73,7 @@ func (i MessageIntegrity) AddTo(m *Message) error {
length := m.Length length := m.Length
// Adjusting m.Length to contain MESSAGE-INTEGRITY TLV. // Adjusting m.Length to contain MESSAGE-INTEGRITY TLV.
m.Length += messageIntegritySize + attributeHeaderSize m.Length += messageIntegritySize + attributeHeaderSize
m.WriteLength() // writing length to m.Raw m.WriteLength() // writing length to m.Raw
v := newHMAC(i, m.Raw) // calculating HMAC for adjusted m.Raw v := newHMAC(i, m.Raw) // calculating HMAC for adjusted m.Raw
m.Length = length // changing m.Length back m.Length = length // changing m.Length back
m.Add(AttrMessageIntegrity, v) m.Add(AttrMessageIntegrity, v)

View File

@ -3,7 +3,6 @@ package stun
import ( import (
"crypto/rand" "crypto/rand"
"encoding/base64" "encoding/base64"
"encoding/binary"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -202,10 +201,6 @@ func (m *Message) WriteHeader() {
copy(m.Raw[8:messageHeaderSize], m.TransactionID[:]) // transaction ID copy(m.Raw[8:messageHeaderSize], m.TransactionID[:]) // transaction ID
} }
func (m *Message) writeMagicCookie() {
}
func (m *Message) WriteTransactionID() { func (m *Message) WriteTransactionID() {
copy(m.Raw[8:messageHeaderSize], m.TransactionID[:]) // transaction ID copy(m.Raw[8:messageHeaderSize], m.TransactionID[:]) // transaction ID
} }
@ -272,10 +267,10 @@ func (m *Message) Decode() error {
return ErrUnexpectedHeaderEOF return ErrUnexpectedHeaderEOF
} }
var ( var (
t = binary.BigEndian.Uint16(buf[0:2]) // first 2 bytes t = bin.Uint16(buf[0:2]) // first 2 bytes
size = int(binary.BigEndian.Uint16(buf[2:4])) // second 2 bytes size = int(bin.Uint16(buf[2:4])) // second 2 bytes
cookie = binary.BigEndian.Uint32(buf[4:8]) cookie = bin.Uint32(buf[4:8]) // last 4 bytes
fullSize = messageHeaderSize + size fullSize = messageHeaderSize + size // len(m.Raw)
) )
if cookie != magicCookie { if cookie != magicCookie {
msg := fmt.Sprintf( msg := fmt.Sprintf(
@ -322,8 +317,8 @@ func (m *Message) Decode() error {
offset += attributeHeaderSize offset += attributeHeaderSize
if len(b) < aBuffL { // checking size if len(b) < aBuffL { // checking size
msg := fmt.Sprintf( msg := fmt.Sprintf(
"buffer length %d is less than %d (expected value size)", "buffer length %d is less than %d (expected value size for %s)",
len(b), aBuffL, len(b), aBuffL, a.Type,
) )
return newAttrDecodeErr("value", msg) return newAttrDecodeErr("value", msg)
} }
@ -344,10 +339,6 @@ func (m *Message) Write(tBuf []byte) (int, error) {
return len(tBuf), m.Decode() return len(tBuf), m.Decode()
} }
// MaxPacketSize is maximum size of UDP packet that is processable in
// this package for STUN message.
const MaxPacketSize = 2048
// MessageClass is 8-bit representation of 2-bit class of STUN Message Class. // MessageClass is 8-bit representation of 2-bit class of STUN Message Class.
type MessageClass byte type MessageClass byte
@ -359,6 +350,16 @@ const (
ClassErrorResponse MessageClass = 0x03 // 0b11 ClassErrorResponse MessageClass = 0x03 // 0b11
) )
// Common STUN message types.
var (
// Binding request message type.
BindingRequest = NewType(MethodBinding, ClassRequest)
// Binding success response message type
BindingSuccess = NewType(MethodBinding, ClassSuccessResponse)
// Binding error response message type.
BindingError = NewType(MethodBinding, ClassErrorResponse)
)
func (c MessageClass) String() string { func (c MessageClass) String() string {
switch c { switch c {
case ClassRequest: case ClassRequest:
@ -411,16 +412,18 @@ func (m Method) String() string {
// MessageType is STUN Message Type Field. // MessageType is STUN Message Type Field.
type MessageType struct { type MessageType struct {
Class MessageClass Method Method // e.g. binding
Method Method Class MessageClass // e.g. request
} }
// AddTo sets m type to t.
func (t MessageType) AddTo(m *Message) error { func (t MessageType) AddTo(m *Message) error {
m.SetType(t) m.SetType(t)
return nil return nil
} }
func NewType(class MessageClass, method Method) MessageType { // NewType returns new message type with provided method and class.
func NewType(method Method, class MessageClass) MessageType {
return MessageType{ return MessageType{
Method: method, Method: method,
Class: class, Class: class,
@ -465,7 +468,7 @@ func (t MessageType) Value() uint16 {
// Shifting to add "holes" for C0 (at 4 bit) and C1 (8 bit). // Shifting to add "holes" for C0 (at 4 bit) and C1 (8 bit).
m = a + (b << methodBShift) + (d << methodDShift) m = a + (b << methodBShift) + (d << methodDShift)
// C0 is zero bit of C, C1 is fist bit. // C0 is zero bit of C, C1 is first bit.
// C0 = C * 0b01, C1 = (C * 0b10) >> 1 // C0 = C * 0b01, C1 = (C * 0b10) >> 1
// Ct = C0 << 4 + C1 << 8. // Ct = C0 << 4 + C1 << 8.
// Optimizations: "((C * 0b10) >> 1) << 8" as "(C * 0b10) << 7" // Optimizations: "((C * 0b10) >> 1) << 8" as "(C * 0b10) << 7"

View File

@ -1,41 +0,0 @@
package stun
import "errors"
// Nonce represents NONCE attribute.
//
// https://tools.ietf.org/html/rfc5389#section-15.8
type Nonce []byte
// NewNonce returns new Nonce from string.
func NewNonce(nonce string) Nonce {
return Nonce(nonce)
}
func (n Nonce) String() string {
return string(n)
}
const maxNonceB = 763
// ErrNonceTooBig means that NONCE value is bigger that 763 bytes.
var ErrNonceTooBig = errors.New("NONCE value bigger than 763 bytes")
// AddTo adds NONCE to message.
func (n Nonce) AddTo(m *Message) error {
if len(n) > maxNonceB {
return ErrNonceTooBig
}
m.Add(AttrNonce, n)
return nil
}
// GetFrom gets NONCE from message.
func (n *Nonce) GetFrom(m *Message) error {
v, err := m.Get(AttrNonce)
if err != nil {
return err
}
*n = v
return nil
}

View File

@ -1,62 +0,0 @@
package stun
import (
"strings"
"testing"
)
func TestNonce_GetFrom(t *testing.T) {
m := New()
v := "example.org"
m.Add(AttrNonce, []byte(v))
m.WriteHeader()
m2 := &Message{
Raw: make([]byte, 0, 256),
}
var nonce Nonce
if _, err := m2.ReadFrom(m.reader()); err != nil {
t.Error(err)
}
if err := nonce.GetFrom(m); err != nil {
t.Fatal(err)
}
if nonce.String() != v {
t.Errorf("Expected %q, got %q.", v, nonce)
}
nAttr, ok := m.Attributes.Get(AttrNonce)
if !ok {
t.Error("nonce attribute should be found")
}
s := nAttr.String()
if !strings.HasPrefix(s, "NONCE:") {
t.Error("bad string representation", s)
}
}
func TestNonce_AddTo_Invalid(t *testing.T) {
m := New()
n := make(Nonce, 1024)
if err := n.AddTo(m); err != ErrNonceTooBig {
t.Errorf("AddTo should return %q, got: %v", ErrNonceTooBig, err)
}
if err := n.GetFrom(m); err != ErrAttributeNotFound {
t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err)
}
}
func TestNonce_AddTo(t *testing.T) {
m := New()
n := Nonce("example.org")
if err := n.AddTo(m); err != nil {
t.Error(err)
}
v, err := m.Get(AttrNonce)
if err != nil {
t.Error(err)
}
if string(v) != "example.org" {
t.Errorf("bad nonce %q", v)
}
}

View File

@ -1,17 +0,0 @@
package stun
// STUN aligns attributes on 32-bit boundaries, attributes whose content
// is not a multiple of 4 bytes are padded with 1, 2, or 3 bytes of
// padding so that its value contains a multiple of 4 bytes. The
// padding bits are ignored, and may be any value.
//
// https://tools.ietf.org/html/rfc5389#section-15
const padding = 4
func nearestPaddedValueLength(l int) int {
n := padding * (l / padding)
if n < l {
n += padding
}
return n
}

View File

@ -1,28 +0,0 @@
package stun
import "testing"
func TestPadding(t *testing.T) {
tt := []struct {
in, out int
}{
{4, 4}, // 0
{2, 4}, // 1
{5, 8}, // 2
{8, 8}, // 3
{11, 12}, // 4
{1, 4}, // 5
{3, 4}, // 6
{6, 8}, // 7
{7, 8}, // 8
{0, 0}, // 9
{40, 40}, // 10
}
for i, c := range tt {
if got := nearestPaddedValueLength(c.in); got != c.out {
t.Errorf("[%d]: padd(%d) %d (got) != %d (expected)",
i, c.in, got, c.out,
)
}
}
}

View File

@ -1,43 +0,0 @@
package stun
import "errors"
// NewRealm returns Realm with provided value.
// Must be SASL-prepared.
func NewRealm(realm string) Realm {
// TODO: use sasl
return Realm(realm)
}
// Realm represents REALM attribute.
//
// https://tools.ietf.org/html/rfc5389#section-15.7
type Realm []byte
func (n Realm) String() string {
return string(n)
}
const maxRealmB = 763
// ErrRealmTooBig means that REALM value is bigger that 763 bytes.
var ErrRealmTooBig = errors.New("REALM value bigger than 763 bytes")
// AddTo adds NONCE to message.
func (n Realm) AddTo(m *Message) error {
if len(n) > maxRealmB {
return ErrRealmTooBig
}
m.Add(AttrRealm, n)
return nil
}
// GetFrom gets REALM from message.
func (n *Realm) GetFrom(m *Message) error {
v, err := m.Get(AttrRealm)
if err != nil {
return err
}
*n = v
return nil
}

View File

@ -1,50 +0,0 @@
package stun
import (
"strings"
"testing"
)
func TestRealm_GetFrom(t *testing.T) {
m := New()
v := "realm"
m.Add(AttrRealm, []byte(v))
m.WriteHeader()
m2 := &Message{
Raw: make([]byte, 0, 256),
}
r := new(Realm)
if err := r.GetFrom(m2); err != ErrAttributeNotFound {
t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err)
}
if _, err := m2.ReadFrom(m.reader()); err != nil {
t.Error(err)
}
if err := r.GetFrom(m); err != nil {
t.Fatal(err)
}
if r.String() != v {
t.Errorf("Expected %q, got %q.", v, r)
}
rAttr, ok := m.Attributes.Get(AttrRealm)
if !ok {
t.Error("realm attribute should be found")
}
s := rAttr.String()
if !strings.HasPrefix(s, "REALM:") {
t.Error("bad string representation", s)
}
}
func TestRealm_AddTo_Invalid(t *testing.T) {
m := New()
r := make(Realm, 1024)
if err := r.AddTo(m); err != ErrRealmTooBig {
t.Errorf("AddTo should return %q, got: %v", ErrRealmTooBig, err)
}
if err := r.GetFrom(m); err != ErrAttributeNotFound {
t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err)
}
}

View File

@ -1,44 +0,0 @@
package stun
import "errors"
const softwareRawMaxB = 763
// ErrSoftwareTooBig means that it is not less than 128 characters
// (which can be as long as 763 bytes).
var ErrSoftwareTooBig = errors.New(
"SOFTWARE attribute bigger than 763 bytes or 128 characters",
)
// Software is SOFTWARE attribute.
//
// https://tools.ietf.org/html/rfc5389#section-15.10
type Software []byte
func (s Software) String() string {
return string(s)
}
// NewSoftware returns *Software from string.
func NewSoftware(software string) Software {
return Software(software)
}
// AddTo adds Software attribute to m.
func (s Software) AddTo(m *Message) error {
if len(s) > softwareRawMaxB {
return ErrSoftwareTooBig
}
m.Add(AttrSoftware, m.Raw)
return nil
}
// GetFrom decodes Software from m.
func (s *Software) GetFrom(m *Message) error {
v, err := m.Get(AttrSoftware)
if err != nil {
return err
}
*s = v
return nil
}

View File

@ -1,56 +0,0 @@
package stun
import (
"strings"
"testing"
)
func TestSoftware_GetFrom(t *testing.T) {
m := New()
v := "Client v0.0.1"
m.Add(AttrSoftware, []byte(v))
m.WriteHeader()
m2 := &Message{
Raw: make([]byte, 0, 256),
}
software := new(Software)
if _, err := m2.ReadFrom(m.reader()); err != nil {
t.Error(err)
}
if err := software.GetFrom(m); err != nil {
t.Fatal(err)
}
if software.String() != v {
t.Errorf("Expected %q, got %q.", v, software)
}
sAttr, ok := m.Attributes.Get(AttrSoftware)
if !ok {
t.Error("sowfware attribute should be found")
}
s := sAttr.String()
if !strings.HasPrefix(s, "SOFTWARE:") {
t.Error("bad string representation", s)
}
}
func TestSoftware_AddTo_Invalid(t *testing.T) {
m := New()
s := make(Software, 1024)
if err := s.AddTo(m); err != ErrSoftwareTooBig {
t.Errorf("AddTo should return %q, got: %v", ErrSoftwareTooBig, err)
}
if err := s.GetFrom(m); err != ErrAttributeNotFound {
t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err)
}
}
func TestSoftware_AddTo_Regression(t *testing.T) {
// s.AddTo checked len(m.Raw) instead of len(s.Raw).
m := &Message{Raw: make([]byte, 2048)}
s := make(Software, 100)
if err := s.AddTo(m); err != nil {
t.Errorf("AddTo should return <nil>, got: %v", err)
}
}

19
stun.go
View File

@ -1,4 +1,7 @@
// Package stun implements Session Traversal Utilities for NAT (STUN) RFC 5389. // Package stun implements Session Traversal Utilities for NAT (STUN) RFC 5389.
//
// The stun package is intended to use by package that implements extension
// to STUN (e.g. TURN) or client/server applications.
package stun package stun
import "encoding/binary" import "encoding/binary"
@ -6,5 +9,17 @@ import "encoding/binary"
// bin is shorthand to binary.BigEndian. // bin is shorthand to binary.BigEndian.
var bin = binary.BigEndian var bin = binary.BigEndian
// DefaultPort is IANA assigned Port for "stun" protocol. // IANA assigned ports for "stun" protocol/
const DefaultPort = 3478 const (
DefaultPort = 3478
DefaultTLSPort = 5349
)
type transactionIDSetter bool
func (transactionIDSetter) AddTo(m *Message) error {
return m.NewTransactionID()
}
// TransactionID is Setter for m.TransactionID.
var TransactionID Setter = transactionIDSetter(true)

134
textattrs.go Normal file
View File

@ -0,0 +1,134 @@
package stun
// NewUsername returns Username with provided value.
func NewUsername(username string) Username {
return Username(username)
}
// Username represents USERNAME attribute.
//
// https://tools.ietf.org/html/rfc5389#section-15.3
type Username []byte
func (u Username) String() string {
return string(u)
}
const maxUsernameB = 513
// AddTo adds USERNAME attribute to message.
func (u Username) AddTo(m *Message) error {
return TextAttribute(u).AddToAs(m, AttrUsername, maxUsernameB)
}
// GetFrom gets USERNAME from message.
func (u *Username) GetFrom(m *Message) error {
return (*TextAttribute)(u).GetFromAs(m, AttrUsername)
}
// NewRealm returns Realm with provided value.
// Must be SASL-prepared.
func NewRealm(realm string) Realm {
// TODO: use sasl
return Realm(realm)
}
// Realm represents REALM attribute.
//
// https://tools.ietf.org/html/rfc5389#section-15.7
type Realm []byte
func (n Realm) String() string {
return string(n)
}
const maxRealmB = 763
// AddTo adds NONCE to message.
func (n Realm) AddTo(m *Message) error {
return TextAttribute(n).AddToAs(m, AttrRealm, maxRealmB)
}
// GetFrom gets REALM from message.
func (n *Realm) GetFrom(m *Message) error {
return (*TextAttribute)(n).GetFromAs(m, AttrRealm)
}
const softwareRawMaxB = 763
// Software is SOFTWARE attribute.
//
// https://tools.ietf.org/html/rfc5389#section-15.10
type Software []byte
func (s Software) String() string {
return string(s)
}
// NewSoftware returns *Software from string.
func NewSoftware(software string) Software {
return Software(software)
}
// AddTo adds Software attribute to m.
func (s Software) AddTo(m *Message) error {
return TextAttribute(s).AddToAs(m, AttrSoftware, softwareRawMaxB)
}
// GetFrom decodes Software from m.
func (s *Software) GetFrom(m *Message) error {
return (*TextAttribute)(s).GetFromAs(m, AttrSoftware)
}
// Nonce represents NONCE attribute.
//
// https://tools.ietf.org/html/rfc5389#section-15.8
type Nonce []byte
// NewNonce returns new Nonce from string.
func NewNonce(nonce string) Nonce {
return Nonce(nonce)
}
func (n Nonce) String() string {
return string(n)
}
const maxNonceB = 763
// AddTo adds NONCE to message.
func (n Nonce) AddTo(m *Message) error {
return TextAttribute(n).AddToAs(m, AttrNonce, maxNonceB)
}
// GetFrom gets NONCE from message.
func (n *Nonce) GetFrom(m *Message) error {
return (*TextAttribute)(n).GetFromAs(m, AttrNonce)
}
// TextAttribute is helper for adding and getting text attributes.
type TextAttribute []byte
// AddToAs adds attribute with type t to m, checking maximum length. If maxLen
// is less than 0, no check is performed.
func (v TextAttribute) AddToAs(m *Message, t AttrType, maxLen int) error {
if maxLen > 0 && len(v) > maxLen {
return &AttrLengthError{
Max: maxLen,
Got: len(v),
Type: t,
}
}
m.Add(t, v)
return nil
}
// GetFromAs gets t attribute from m and appends its value to reseted v.
func (v *TextAttribute) GetFromAs(m *Message, t AttrType) error {
a, err := m.Get(t)
if err != nil {
return err
}
*v = append((*v)[:0], a...)
return nil
}

249
textattrs_test.go Normal file
View File

@ -0,0 +1,249 @@
package stun
import (
"strings"
"testing"
)
func TestSoftware_GetFrom(t *testing.T) {
m := New()
v := "Client v0.0.1"
m.Add(AttrSoftware, []byte(v))
m.WriteHeader()
m2 := &Message{
Raw: make([]byte, 0, 256),
}
software := new(Software)
if _, err := m2.ReadFrom(m.reader()); err != nil {
t.Error(err)
}
if err := software.GetFrom(m); err != nil {
t.Fatal(err)
}
if software.String() != v {
t.Errorf("Expected %q, got %q.", v, software)
}
sAttr, ok := m.Attributes.Get(AttrSoftware)
if !ok {
t.Error("sowfware attribute should be found")
}
s := sAttr.String()
if !strings.HasPrefix(s, "SOFTWARE:") {
t.Error("bad string representation", s)
}
}
func TestSoftware_AddTo_Invalid(t *testing.T) {
m := New()
s := make(Software, 1024)
if err, ok := s.AddTo(m).(*AttrLengthError); !ok {
t.Errorf("AddTo should return *AttrLengthError, got: %v", err)
}
if err := s.GetFrom(m); err != ErrAttributeNotFound {
t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err)
}
}
func TestSoftware_AddTo_Regression(t *testing.T) {
// s.AddTo checked len(m.Raw) instead of len(s.Raw).
m := &Message{Raw: make([]byte, 2048)}
s := make(Software, 100)
if err := s.AddTo(m); err != nil {
t.Errorf("AddTo should return <nil>, got: %v", err)
}
}
func BenchmarkUsername_AddTo(b *testing.B) {
b.ReportAllocs()
m := new(Message)
u := Username("test")
for i := 0; i < b.N; i++ {
if err := u.AddTo(m); err != nil {
b.Fatal(err)
}
m.Reset()
}
}
func BenchmarkUsername_GetFrom(b *testing.B) {
b.ReportAllocs()
m := new(Message)
Username("test").AddTo(m)
for i := 0; i < b.N; i++ {
var u Username
if err := u.GetFrom(m); err != nil {
b.Fatal(err)
}
}
}
func TestUsername(t *testing.T) {
username := "username"
u := NewUsername(username)
m := new(Message)
m.WriteHeader()
t.Run("Bad length", func(t *testing.T) {
badU := make(Username, 600)
if err, ok := badU.AddTo(m).(*AttrLengthError); !ok {
t.Errorf("expected length error, got %v", err)
}
})
t.Run("AddTo", func(t *testing.T) {
if err := u.AddTo(m); err != nil {
t.Error("errored:", err)
}
t.Run("GetFrom", func(t *testing.T) {
got := new(Username)
if err := got.GetFrom(m); err != nil {
t.Error("errored:", err)
}
if got.String() != username {
t.Errorf("expedted: %s, got: %s", username, got)
}
t.Run("Not found", func(t *testing.T) {
m := new(Message)
u := new(Username)
if err := u.GetFrom(m); err != ErrAttributeNotFound {
t.Error("Should error")
}
})
})
})
t.Run("No allocations", func(t *testing.T) {
m := new(Message)
m.WriteHeader()
u := NewUsername("username")
if allocs := testing.AllocsPerRun(10, func() {
if err := u.AddTo(m); err != nil {
t.Error(err)
}
m.Reset()
}); allocs > 0 {
t.Errorf("got %f allocations, zero expected", allocs)
}
})
}
func TestRealm_GetFrom(t *testing.T) {
m := New()
v := "realm"
m.Add(AttrRealm, []byte(v))
m.WriteHeader()
m2 := &Message{
Raw: make([]byte, 0, 256),
}
r := new(Realm)
if err := r.GetFrom(m2); err != ErrAttributeNotFound {
t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err)
}
if _, err := m2.ReadFrom(m.reader()); err != nil {
t.Error(err)
}
if err := r.GetFrom(m); err != nil {
t.Fatal(err)
}
if r.String() != v {
t.Errorf("Expected %q, got %q.", v, r)
}
rAttr, ok := m.Attributes.Get(AttrRealm)
if !ok {
t.Error("realm attribute should be found")
}
s := rAttr.String()
if !strings.HasPrefix(s, "REALM:") {
t.Error("bad string representation", s)
}
}
func TestRealm_AddTo_Invalid(t *testing.T) {
m := New()
r := make(Realm, 1024)
if err, ok := r.AddTo(m).(*AttrLengthError); !ok || err.Type != AttrRealm {
t.Errorf("AddTo should return *AttrLengthError, got: %v", err)
}
if err := r.GetFrom(m); err != ErrAttributeNotFound {
t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err)
}
}
func TestNonce_GetFrom(t *testing.T) {
m := New()
v := "example.org"
m.Add(AttrNonce, []byte(v))
m.WriteHeader()
m2 := &Message{
Raw: make([]byte, 0, 256),
}
var nonce Nonce
if _, err := m2.ReadFrom(m.reader()); err != nil {
t.Error(err)
}
if err := nonce.GetFrom(m); err != nil {
t.Fatal(err)
}
if nonce.String() != v {
t.Errorf("Expected %q, got %q.", v, nonce)
}
nAttr, ok := m.Attributes.Get(AttrNonce)
if !ok {
t.Error("nonce attribute should be found")
}
s := nAttr.String()
if !strings.HasPrefix(s, "NONCE:") {
t.Error("bad string representation", s)
}
}
func TestNonce_AddTo_Invalid(t *testing.T) {
m := New()
n := make(Nonce, 1024)
if err, ok := n.AddTo(m).(*AttrLengthError); !ok || err.Type != AttrNonce {
t.Errorf("AddTo should return *AttrLengthError, got: %v", err)
}
if err := n.GetFrom(m); err != ErrAttributeNotFound {
t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err)
}
}
func TestNonce_AddTo(t *testing.T) {
m := New()
n := Nonce("example.org")
if err := n.AddTo(m); err != nil {
t.Error(err)
}
v, err := m.Get(AttrNonce)
if err != nil {
t.Error(err)
}
if string(v) != "example.org" {
t.Errorf("bad nonce %q", v)
}
}
func BenchmarkNonce_AddTo(b *testing.B) {
b.ReportAllocs()
m := New()
n := NewNonce("nonce")
for i := 0; i < b.N; i++ {
if err := n.AddTo(m); err != nil {
b.Fatal(err)
}
m.Reset()
}
}
func BenchmarkNonce_GetFrom(b *testing.B) {
b.ReportAllocs()
m := New()
n := NewNonce("nonce")
n.AddTo(m)
for i := 0; i < b.N; i++ {
n.GetFrom(m)
}
}

View File

@ -1,10 +0,0 @@
package stun
type transactionIDSetter bool
func (transactionIDSetter) AddTo(m *Message) error {
return m.NewTransactionID()
}
// TransactionID is Setter for m.TransactionID.
var TransactionID Setter = transactionIDSetter(true)

View File

@ -1,41 +0,0 @@
package stun
import "errors"
// NewUsername returns Username with provided value.
func NewUsername(username string) Username {
return Username(username)
}
// Username represents USERNAME attribute.
//
// https://tools.ietf.org/html/rfc5389#section-15.3
type Username []byte
func (u Username) String() string {
return string(u)
}
const maxUsernameB = 513
// ErrUsernameTooBig means that USERNAME value is bigger that 513 bytes.
var ErrUsernameTooBig = errors.New("USERNAME value bigger than 513 bytes")
// AddTo adds USERNAME attribute to message.
func (u Username) AddTo(m *Message) error {
if len(u) > maxUsernameB {
return ErrUsernameTooBig
}
m.Add(AttrUsername, u)
return nil
}
// GetFrom gets USERNAME from message.
func (u *Username) GetFrom(m *Message) error {
v, err := m.Get(AttrUsername)
if err != nil {
return err
}
*u = v
return nil
}

View File

@ -1,76 +0,0 @@
package stun
import (
"testing"
)
func BenchmarkUsername_AddTo(b *testing.B) {
b.ReportAllocs()
m := new(Message)
u := Username("test")
for i := 0; i < b.N; i++ {
if err := u.AddTo(m); err != nil {
b.Fatal(err)
}
m.Reset()
}
}
func BenchmarkUsername_GetFrom(b *testing.B) {
b.ReportAllocs()
m := new(Message)
Username("test").AddTo(m)
for i := 0; i < b.N; i++ {
var u Username
if err := u.GetFrom(m); err != nil {
b.Fatal(err)
}
}
}
func TestUsername(t *testing.T) {
username := "username"
u := NewUsername(username)
m := new(Message)
m.WriteHeader()
t.Run("Bad length", func(t *testing.T) {
badU := make(Username, 600)
if err := badU.AddTo(m); err != ErrUsernameTooBig {
t.Errorf("expected %s, got %v", ErrUsernameTooBig, err)
}
})
t.Run("AddTo", func(t *testing.T) {
if err := u.AddTo(m); err != nil {
t.Error("errored:", err)
}
t.Run("GetFrom", func(t *testing.T) {
got := new(Username)
if err := got.GetFrom(m); err != nil {
t.Error("errored:", err)
}
if got.String() != username {
t.Errorf("expedted: %s, got: %s", username, got)
}
t.Run("Not found", func(t *testing.T) {
m := new(Message)
u := new(Username)
if err := u.GetFrom(m); err != ErrAttributeNotFound {
t.Error("Should error")
}
})
})
})
t.Run("No allocations", func(t *testing.T) {
m := new(Message)
m.WriteHeader()
u := NewUsername("username")
if allocs := testing.AllocsPerRun(10, func() {
if err := u.AddTo(m); err != nil {
t.Error(err)
}
m.Reset()
}); allocs > 0 {
t.Errorf("got %f allocations, zero expected", allocs)
}
})
}