-add 新增rtmp推流实现

This commit is contained in:
renyaqi
2022-11-07 17:27:29 +08:00
commit de8c7d67fc
63 changed files with 10425 additions and 0 deletions
+1
View File
@@ -0,0 +1 @@
go.sum
+3
View File
@@ -0,0 +1,3 @@
# 项目说明
简单的go语言rtmp推流实现
+65
View File
@@ -0,0 +1,65 @@
package gortmppush
import (
"github.com/H0RlZ0N/gortmppush/av"
"github.com/H0RlZ0N/gortmppush/container/flv"
"github.com/H0RlZ0N/gortmppush/logger"
"github.com/H0RlZ0N/gortmppush/protocol"
)
// RtmpAPI api接口类
type RtmpAPI struct {
setting *SettingEngine
logger logger.Logger
}
// NewAPI 创建一个api,设置相应的参数信息
func NewAPI(opts ...SettingFunc) *RtmpAPI {
api := &RtmpAPI{}
setting := &SettingEngine{}
for _, v := range opts {
v(setting)
}
if setting.loggerFactory == nil {
setting.loggerFactory = logger.NewDefaultFactory()
}
if setting.logLevel == logger.LogLevelDisabled {
setting.logLevel = logger.LogLevelInfo
}
api.logger = setting.loggerFactory.NewLogger(setting.logLevel)
api.setting = setting
return api
}
// ServeRtmp 创建一个rtmp服务,并监听响应的地址
func (api *RtmpAPI) ServeRtmp(addr string) error {
server := &Server{
handler: protocol.NewStreamHandler(api.logger),
logger: api.logger,
}
return server.Serve(addr)
}
// ServeRtmpTLS 创建一个rtmp服务,并监听响应的地址
func (api *RtmpAPI) ServeRtmpTLS(addr, tlsKey, tlsCrt string) error {
server := &Server{
handler: protocol.NewStreamHandler(api.logger),
logger: api.logger,
}
return server.ServeTLS(addr, tlsKey, tlsCrt)
}
// NewRtmpClient 创建一个rtmp客户端
func (api *RtmpAPI) NewRtmpClient() *RtmpClient {
client := &RtmpClient{
packetChan: make(chan *av.Packet, 16),
videoFirst: true,
audioFirst: true,
demuxer: flv.NewDemuxer(),
logger: api.logger,
}
return client
}
+135
View File
@@ -0,0 +1,135 @@
package av
import (
"io"
)
const (
TAG_AUDIO = 8
TAG_VIDEO = 9
TAG_SCRIPTDATAAMF0 = 18
TAG_SCRIPTDATAAMF3 = 0xf
)
const (
MetadatAMF0 = 0x12
MetadataAMF3 = 0xf
)
const (
SOUND_MP3 = 2
SOUND_NELLYMOSER_16KHZ_MONO = 4
SOUND_NELLYMOSER_8KHZ_MONO = 5
SOUND_NELLYMOSER = 6
SOUND_ALAW = 7
SOUND_MULAW = 8
SOUND_AAC = 10
SOUND_SPEEX = 11
//rtmp tag中只支持前面四种采样率,后面是为了程序处理方便自己添加
SOUND_RATE_5_5Khz = 0 //
SOUND_RATE_11Khz = 1 //11025 hz
SOUND_RATE_22Khz = 2 //22050 hz
SOUND_RATE_44Khz = 3
//自己添加
SOUND_RATE_7Khz = 4 //7350 hz
SOUND_RATE_8Khz = 5 //8000 hz
SOUND_RATE_12Khz = 6 //12000 hz
SOUND_RATE_16Khz = 7 //16000 hz
SOUND_RATE_24Khz = 8 //24000 hz
SOUND_RATE_32Khz = 9 //32000 hz
SOUND_RATE_48Khz = 10 //48000 hz
SOUND_RATE_64Khz = 11 //64000 hz
SOUND_RATE_88Khz = 12 //88200 hz
SOUND_RATE_96Khz = 13 // 96000 hz
SOUND_8BIT = 0
SOUND_16BIT = 1
SOUND_MONO = 0
SOUND_STEREO = 1
AAC_SEQHDR = 0
AAC_RAW = 1
)
const (
//视频tag的帧类型, 对于avc(h264)只用到了前面两个
FRAME_KEY = 1 // keyframe for avc, a seekable frame)
FRAME_INTER = 2 // inter frame (for avc, a non-seekable frame)
//3:disposable inter frame
//4:generated keyframe(reserved for server use only)
//5:vidoe info/command frame
//avc 视频封装格式
AVC_SEQHDR = 0 // avc sequence header
AVC_NALU = 1 // avc nalu
AVC_EOS = 2 // avc end of sequence
//avc视频编码id
VIDEO_JPEG = 1
VideoH263 = 2
VideoScreen = 3
VideoVP6 = 4
VideoVP6WithAlpha = 5
VideoScreenV2 = 6
VIDEO_H264 = 7
)
// Packet类型
const (
PacketTypeUnknow = 0
PacketTypeVideo = 1 //音频包
PacketTypeAudio = 2 //视频包
PacketTypeMetadata = 3 //数据包
)
var (
PUBLISH = "publish"
PLAY = "play"
)
// Header can be converted to AudioHeaderInfo or VideoHeaderInfo
type Packet struct {
PacketType uint32 // packet类型
TimeStamp uint32 // dts
StreamID uint32
VHeader VideoPacketHeader
AHeader AudioPacketHeader
Data []byte
}
// AudioPacketHeader comment
type AudioPacketHeader struct {
SoundFormat uint8
SoundRate uint8
SoundSize uint8
SoundType uint8
AACPacketType uint8
}
// VideoPacketHeader ...
type VideoPacketHeader struct {
FrameType uint8
AVCPacketType uint8
CodecID uint8
CompositionTime int32
}
type Demuxer interface {
Demux(*Packet) (ret *Packet, err error)
}
type Muxer interface {
Mux(*Packet, io.Writer) error
}
type SampleRater interface {
SampleRate() (int, error)
}
type CodecParser interface {
SampleRater
Parse(*Packet, io.Writer) error
}
+51
View File
@@ -0,0 +1,51 @@
package av
import (
"sync"
"time"
)
type RWBaser struct {
lock sync.Mutex
timeout time.Duration
PreTime time.Time
BaseTimestamp uint32
LastVideoTimestamp uint32
LastAudioTimestamp uint32
}
func NewRWBaser(duration time.Duration) RWBaser {
return RWBaser{
timeout: duration,
PreTime: time.Now(),
}
}
func (rw *RWBaser) BaseTimeStamp() uint32 {
return rw.BaseTimestamp
}
func (rw *RWBaser) CalcBaseTimestamp() {
if rw.LastAudioTimestamp > rw.LastVideoTimestamp {
rw.BaseTimestamp = rw.LastAudioTimestamp
} else {
rw.BaseTimestamp = rw.LastVideoTimestamp
}
}
func (rw *RWBaser) RecTimeStamp(timestamp, typeID uint32) {
if typeID == TAG_VIDEO {
rw.LastVideoTimestamp = timestamp
} else if typeID == TAG_AUDIO {
rw.LastAudioTimestamp = timestamp
}
}
func (rw *RWBaser) SetPreTime() {
rw.PreTime = time.Now()
}
func (rw *RWBaser) Alive() bool {
b := !(time.Now().Sub(rw.PreTime) >= rw.timeout)
return b
}
+356
View File
@@ -0,0 +1,356 @@
package gortmppush
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"github.com/H0RlZ0N/gortmppush/av"
"github.com/H0RlZ0N/gortmppush/container/flv"
"github.com/H0RlZ0N/gortmppush/logger"
"github.com/H0RlZ0N/gortmppush/media/h264"
"github.com/H0RlZ0N/gortmppush/protocol/amf"
"github.com/H0RlZ0N/gortmppush/protocol/core"
)
// RtmpClient ...
type RtmpClient struct {
packetChan chan *av.Packet
conn *core.ConnClient
onPacketReceive func(*av.Packet)
onClosed func()
isPublish bool
videoFirst bool //first packet to send
audioFirst bool
demuxer *flv.Demuxer
logger logger.Logger
}
// NewRtmpClient comment
func NewRtmpClient(log logger.Logger) *RtmpClient {
return &RtmpClient{
packetChan: make(chan *av.Packet, 16),
videoFirst: true,
audioFirst: true,
demuxer: flv.NewDemuxer(),
logger: log,
}
}
// OpenPublish comment
func (c *RtmpClient) OpenPublish(URL string) (err error) {
c.conn = core.NewConnClient(c.logger)
if err = c.conn.Start(URL, "publish"); err != nil {
return
}
c.isPublish = true
return
}
// OpenPlay comment
func (c *RtmpClient) OpenPlay(URL string, onPacketReceive func(*av.Packet), onClosed func()) (err error) {
c.conn = core.NewConnClient(c.logger)
if err = c.conn.Start(URL, "play"); err != nil {
return
}
c.onPacketReceive = onPacketReceive
c.onClosed = onClosed
go c.streamPlayProc()
return
}
// Close 关闭连接,并回调onClosed
func (c *RtmpClient) Close() error {
c.conn.Close()
if c.onClosed != nil {
c.onClosed()
}
return nil
}
// SendPacket 发送数据包
func (c *RtmpClient) SendPacket(pkt *av.Packet) error {
if !c.isPublish {
return fmt.Errorf("It is not publish mode")
}
switch pkt.PacketType {
case av.PacketTypeAudio:
return c.sendAudioPacket(pkt)
case av.PacketTypeVideo:
return c.sendVideoPacket(pkt)
case av.PacketTypeMetadata:
return c.sendMetaPacket(pkt)
default:
return fmt.Errorf("Unknow packet type:%d", pkt.PacketType)
}
}
func (c *RtmpClient) sendAudioPacket(pkt *av.Packet) error {
var err error
if pkt.AHeader.SoundFormat == av.SOUND_AAC && c.audioFirst {
//如果音频是aac,需要先发送aac sequence header
sequencePkt := &av.Packet{
PacketType: av.PacketTypeAudio,
Data: flv.NewAACSequenceHeader(pkt.AHeader),
TimeStamp: pkt.TimeStamp,
}
if err = c.sendPacketData(sequencePkt.Data, sequencePkt.TimeStamp,
av.PacketTypeAudio); err != nil {
return fmt.Errorf("send aac sequence header failed. %v", err)
}
c.audioFirst = false
}
if pkt.Data, err = flv.PackAudioData(&pkt.AHeader, pkt.StreamID, pkt.Data,
pkt.TimeStamp); err != nil {
return fmt.Errorf("Pack audio failed. %v", err)
}
if err = c.sendPacketData(pkt.Data, pkt.TimeStamp, av.PacketTypeAudio); err != nil {
return fmt.Errorf("send packet failed, %v", err)
}
return nil
}
func (c *RtmpClient) sendVideoPacket(pkt *av.Packet) error {
var err error
if pkt.VHeader.CodecID == av.VIDEO_H264 {
// 如果是h264,第一帧要发送sequence header
if c.videoFirst {
var sps, pps []byte
nalus := h264.ParseNalus(pkt.Data)
for _, nalu := range nalus {
if naluType := nalu[0] & 0x1F; naluType == 7 {
sps = nalu
} else if naluType == 8 {
pps = nalu
}
}
if sps == nil || pps == nil {
c.logger.Warn("sps and pps need for first packet.")
return nil
}
//send flv sequence header
sequencePkt := &av.Packet{
PacketType: av.PacketTypeVideo,
Data: flv.NewAVCSequenceHeader(sps, pps, pkt.TimeStamp),
TimeStamp: pkt.TimeStamp,
}
if err = c.sendPacketData(sequencePkt.Data, sequencePkt.TimeStamp,
av.PacketTypeVideo); err != nil {
return fmt.Errorf("send flv sequence header failed, %v", err)
}
c.videoFirst = false
}
}
if pkt.Data, err = flv.PackVideoData(&pkt.VHeader, pkt.StreamID, pkt.Data,
pkt.TimeStamp); err != nil {
return fmt.Errorf("Pack video failed, %v", err)
}
if err = c.sendPacketData(pkt.Data, pkt.TimeStamp, av.PacketTypeVideo); err != nil {
return fmt.Errorf("send packet failed, %v", err)
}
return nil
}
func (c *RtmpClient) sendMetaPacket(pkt *av.Packet) error {
return fmt.Errorf("Mata data unsupport")
}
func (c *RtmpClient) sendPacketData(data []byte, timestamp uint32, packetType int) error {
if len(data) == 0 {
return fmt.Errorf("data length is zero")
}
var typeID uint32
switch packetType {
case av.PacketTypeVideo:
typeID = av.TAG_VIDEO
case av.PacketTypeAudio:
typeID = av.TAG_AUDIO
case av.PacketTypeMetadata:
typeID = av.TAG_SCRIPTDATAAMF0
default:
return fmt.Errorf("Unsupport packet type:%d", packetType)
}
// todo 其他的字段值是否有效
cs := core.ChunkStream{
Data: data,
Length: uint32(len(data)),
StreamID: c.conn.GetStreamID(),
Timestamp: timestamp,
TypeID: typeID,
}
if err := c.conn.Write(&cs); err != nil {
return err
} else if err := c.conn.Flush(); err != nil {
return err
}
return nil
}
// 从ChunkStream中解析音频和视频数据
func (c *RtmpClient) handleVideoAudio(cs *core.ChunkStream) error {
var pktType uint32
switch cs.TypeID {
case av.TAG_VIDEO:
pktType = av.PacketTypeVideo
case av.TAG_AUDIO:
pktType = av.PacketTypeAudio
case av.TAG_SCRIPTDATAAMF0, av.TAG_SCRIPTDATAAMF3:
pktType = av.PacketTypeMetadata
default:
return fmt.Errorf("Unknow chunk type:%d", cs.TypeID)
}
var err error
pkt := av.Packet{
Data: cs.Data,
StreamID: cs.StreamID,
TimeStamp: cs.Timestamp,
PacketType: pktType,
}
if err = c.demuxer.Demux(&pkt); err != nil {
return fmt.Errorf("Demux failed, %v", err)
}
switch pkt.PacketType {
case av.PacketTypeAudio: //处理音频数据
c.onPacketReceive(&pkt)
case av.PacketTypeVideo: //处理视频数据
switch pkt.VHeader.CodecID {
case av.VIDEO_H264:
// 如果是h264的sequence header,需要解析出sps和pps
if pkt.VHeader.FrameType == av.FRAME_KEY && pkt.VHeader.AVCPacketType == av.AVC_SEQHDR {
spss, ppss, err := flv.ParseAVCSequenceHeader(pkt.Data)
if err != nil {
return fmt.Errorf("Parse avc sequence header failed, %v", err)
}
//如果解析到多个sps和pps,只返回第一个sps和pps
if len(spss) > 0 {
pkt.Data = spss[0]
c.onPacketReceive(&pkt)
}
if len(ppss) > 0 {
pkt.Data = ppss[0]
c.onPacketReceive(&pkt)
}
return nil
}
default:
}
//解析后的数据格式为 4字节长度+nalue数据+4字节长度+nalu数据。。。
//解析出所以的nalu数据
index := 0
naluData := pkt.Data
for {
remain := len(naluData[index:])
if remain < 4 {
if remain != 0 {
c.logger.Warnf("Invalid data length, remain:%d", remain)
}
return nil
}
length := binary.BigEndian.Uint32(naluData[index:])
if length > uint32(remain-4) {
return fmt.Errorf("invalid data length:%d remain:%d", length, remain-4)
}
index += 4
pkt.Data = naluData[index : index+int(length)]
index += int(length)
c.onPacketReceive(&pkt)
}
case av.TAG_SCRIPTDATAAMF0, av.TAG_SCRIPTDATAAMF3:
return fmt.Errorf("TODO")
default:
return fmt.Errorf("unknow chunk stream type:%d", cs.TypeID)
}
return nil
}
func (c *RtmpClient) handleMetadata(cs *core.ChunkStream) (err error) {
var values []interface{}
r := bytes.NewReader(cs.Data)
if cs.TypeID == av.TAG_SCRIPTDATAAMF0 {
values, err = c.conn.DecodeBatch(r, amf.AMF0)
} else if cs.TypeID == av.TAG_SCRIPTDATAAMF3 {
values, err = c.conn.DecodeBatch(r, amf.AMF3)
}
if err != nil && err != io.EOF {
return fmt.Errorf("decode metadata failed, %v", err)
}
for _, v := range values {
switch v.(type) {
case string:
if v.(string) == "onMetadata" {
//说明该信息是描述视频信息的元数据,可以从afm.Object中获取到相印的属性值
}
case amf.Object:
for k, v1 := range v.(amf.Object) {
c.logger.Debugf("key:%s v:%v", k, v1)
}
default: //其他的忽略不处理
}
}
return nil
}
// 处理命令消息
func (c *RtmpClient) handleCommand(cs *core.ChunkStream) (err error) {
var values []interface{}
r := bytes.NewReader(cs.Data)
if cs.TypeID == 20 {
values, err = c.conn.DecodeBatch(r, amf.AMF0)
} else if cs.TypeID == 17 {
values, err = c.conn.DecodeBatch(r, amf.AMF3)
}
if err != nil && err != io.EOF {
return fmt.Errorf("Decode amf failed, %v", err)
}
for k, v := range values {
c.logger.Tracef("k:%d v:%v", k, v)
}
return nil
}
func (c *RtmpClient) streamPlayProc() {
defer c.Close()
for {
cs, err := c.conn.Read()
if err != nil {
c.logger.Errorf("Read chunk stream failed, %s", err.Error())
break
}
switch cs.TypeID {
case av.TAG_AUDIO, av.TAG_VIDEO:
if err := c.handleVideoAudio(cs); err != nil {
c.logger.Errorf("handle media data failed, %v", err)
}
case av.TAG_SCRIPTDATAAMF0, av.TAG_SCRIPTDATAAMF3:
c.logger.Debug("Receive a scriptdata.....")
if err := c.handleMetadata(cs); err != nil {
c.logger.Errorf("handle metadata failed, %v", err)
}
case 17, 20:
c.logger.Debug("Receive a command message.....")
if err := c.handleCommand(cs); err != nil {
c.logger.Errorf("handle command failed, %v", err)
}
default:
c.logger.Errorf("Unsupport type id:%d", cs.TypeID)
continue
}
}
}
+73
View File
@@ -0,0 +1,73 @@
package configure
import (
"encoding/json"
"fmt"
"io/ioutil"
)
/*
{
[
{
"application":"live",
"live":"on",
"hls":"on",
"static_push":["rtmp://xx/live"]
}
]
}
*/
type Application struct {
Appname string
Liveon string
Hlson string
Static_push []string
}
type ServerCfg struct {
Server []Application
}
var RtmpServercfg ServerCfg
func LoadConfig(configfilename string) error {
fmt.Printf("starting load configure file(%s)...\n", configfilename)
data, err := ioutil.ReadFile(configfilename)
if err != nil {
fmt.Printf("Read file %s error, %v\n", configfilename, err)
return err
}
fmt.Printf("loadconfig:%s\n", string(data))
err = json.Unmarshal(data, &RtmpServercfg)
if err != nil {
fmt.Printf("json.Unmarshal error, %v\n", err)
return err
}
fmt.Printf("get config json data:%v\n", RtmpServercfg)
return nil
}
func CheckAppName(appname string) bool {
for _, app := range RtmpServercfg.Server {
if (app.Appname == appname) && (app.Liveon == "on") {
return true
}
}
return false
}
func GetStaticPushUrlList(appname string) ([]string, bool) {
for _, app := range RtmpServercfg.Server {
if (app.Appname == appname) && (app.Liveon == "on") {
if len(app.Static_push) > 0 {
return app.Static_push, true
} else {
return nil, false
}
}
}
return nil, false
}
+89
View File
@@ -0,0 +1,89 @@
package flv
import (
"errors"
"fmt"
"github.com/H0RlZ0N/gortmppush/av"
)
// ErrAvcEndSEQ ...
var ErrAvcEndSEQ = errors.New("avc end sequence")
type Demuxer struct {
}
// NewDemuxer ...
func NewDemuxer() *Demuxer {
return &Demuxer{}
}
// DemuxH ...
func (d *Demuxer) DemuxH(p *av.Packet) (err error) {
var tag Tag
switch p.PacketType {
case av.PacketTypeAudio:
if _, err = tag.ParseAudioHeader(p.Data); err != nil {
return
}
p.AHeader = av.AudioPacketHeader{
SoundFormat: tag.mediat.soundFormat,
SoundRate: tag.mediat.soundRate,
SoundSize: tag.mediat.soundSize,
SoundType: tag.mediat.soundType,
AACPacketType: tag.mediat.aacPacketType,
}
case av.PacketTypeVideo:
if _, err = tag.ParseVideoHeader(p.Data); err != nil {
return
}
p.VHeader = av.VideoPacketHeader{
FrameType: tag.mediat.frameType,
AVCPacketType: tag.mediat.avcPacketType,
CodecID: tag.mediat.codecID,
CompositionTime: tag.mediat.compositionTime,
}
default:
//todo IsMetadata如何处理
return fmt.Errorf("Unsupport type")
}
return
}
// Demux ...
func (d *Demuxer) Demux(p *av.Packet) (err error) {
var (
tag Tag
n int
)
switch p.PacketType {
case av.PacketTypeAudio:
if n, err = tag.ParseAudioHeader(p.Data); err != nil {
return
}
p.AHeader = av.AudioPacketHeader{
SoundFormat: tag.mediat.soundFormat,
SoundRate: tag.mediat.soundRate,
SoundSize: tag.mediat.soundSize,
SoundType: tag.mediat.soundType,
AACPacketType: tag.mediat.aacPacketType,
}
case av.PacketTypeVideo:
if n, err = tag.ParseVideoHeader(p.Data); err != nil {
return
}
p.VHeader = av.VideoPacketHeader{
FrameType: tag.mediat.frameType,
AVCPacketType: tag.mediat.avcPacketType,
CodecID: tag.mediat.codecID,
CompositionTime: tag.mediat.compositionTime,
}
default:
return fmt.Errorf("Unsupport type:%d", p.PacketType)
}
if err != nil {
return
}
p.Data = p.Data[n:]
return
}
+161
View File
@@ -0,0 +1,161 @@
package flv
// import (
// "flag"
// "fmt"
// "os"
// "strings"
// "time"
// "github.com/H0RlZ0N/gortmppush/av"
// "github.com/H0RlZ0N/gortmppush/protocol/amf"
// "github.com/H0RlZ0N/gortmppush/utils"
// )
// var (
// flvHeader = []byte{0x46, 0x4c, 0x56, 0x01, 0x05, 0x00, 0x00, 0x00, 0x09}
// flvFile = flag.String("filFile", "./out.flv", "output flv file name")
// )
// func NewFlv(handler av.Handler, streamInfo av.StreamInfo) error {
// patths := strings.SplitN(streamInfo.Key, "/", 2)
// if len(patths) != 2 {
// return fmt.Errorf("Invalid key:%s", streamInfo.Key)
// }
// w, err := os.OpenFile(*flvFile, os.O_CREATE|os.O_RDWR, 0755)
// if err != nil {
// return fmt.Errorf("open file failed, %v", err)
// }
// //todo 文件句柄如何关闭
// writer := NewFLVWriter(patths[0], patths[1], streamInfo.URL, w)
// handler.HandleWriter(writer)
// writer.Wait()
// // close flv file
// writer.ctx.Close()
// return nil
// }
// const (
// headerLen = 11
// )
// type FLVWriter struct {
// av.RWBaser
// UID string
// app, title, url string
// buf []byte
// closed chan struct{}
// ctx *os.File
// }
// func NewFLVWriter(app, title, url string, ctx *os.File) *FLVWriter {
// ret := &FLVWriter{
// UID: utils.NewId(),
// app: app,
// title: title,
// url: url,
// ctx: ctx,
// RWBaser: av.NewRWBaser(time.Second * 10),
// closed: make(chan struct{}),
// buf: make([]byte, headerLen),
// }
// ret.ctx.Write(flvHeader)
// utils.PutI32BE(ret.buf[:4], 0)
// ret.ctx.Write(ret.buf[:4])
// return ret
// }
// func (writer *FLVWriter) Write(p *av.Packet) error {
// writer.RWBaser.SetPreTime()
// h := writer.buf[:headerLen]
// typeID := av.TAG_VIDEO
// switch p.PacketType {
// case av.PacketTypeVideo:
// typeID = av.TAG_VIDEO
// case av.PacketTypeAudio:
// typeID = av.TAG_AUDIO
// case av.PacketTypeMetadata:
// var err error
// typeID = av.TAG_SCRIPTDATAAMF0
// p.Data, err = amf.MetaDataReform(p.Data, amf.DEL)
// if err != nil {
// return err
// }
// }
// dataLen := len(p.Data)
// timestamp := p.TimeStamp
// timestamp += writer.BaseTimeStamp()
// writer.RWBaser.RecTimeStamp(timestamp, uint32(typeID))
// preDataLen := dataLen + headerLen
// timestampbase := timestamp & 0xffffff
// timestampExt := timestamp >> 24 & 0xff
// utils.PutU8(h[0:1], uint8(typeID))
// utils.PutI24BE(h[1:4], int32(dataLen))
// utils.PutI24BE(h[4:7], int32(timestampbase))
// utils.PutU8(h[7:8], uint8(timestampExt))
// if _, err := writer.ctx.Write(h); err != nil {
// return err
// }
// if _, err := writer.ctx.Write(p.Data); err != nil {
// return err
// }
// utils.PutI32BE(h[:4], int32(preDataLen))
// if _, err := writer.ctx.Write(h[:4]); err != nil {
// return err
// }
// return nil
// }
// func (writer *FLVWriter) Wait() {
// select {
// case <-writer.closed:
// return
// }
// }
// func (writer *FLVWriter) Close() {
// writer.ctx.Close()
// close(writer.closed)
// }
// func (writer *FLVWriter) StreamInfo() (ret av.StreamInfo) {
// ret.UID = writer.UID
// ret.URL = writer.url
// ret.Key = writer.app + "/" + writer.title
// return
// }
// type FlvDvr struct{}
// func (f *FlvDvr) NewWriter(streamInfo av.StreamInfo) (av.WriteCloser, error) {
// paths := strings.SplitN(streamInfo.Key, "/", 2)
// if len(paths) != 2 {
// return nil, fmt.Errorf("invalid key:%s", streamInfo.Key)
// }
// err := os.MkdirAll(paths[0], 0755)
// if err != nil {
// return nil, fmt.Errorf("mkdir failed, %v", err)
// }
// fileName := fmt.Sprintf("%s_%d.%s", streamInfo.Key, time.Now().Unix(), "flv")
// w, err := os.OpenFile(fileName, os.O_CREATE|os.O_RDWR, 0755)
// if err != nil {
// return nil, fmt.Errorf("open file failed, %v", err)
// }
// writer := NewFLVWriter(paths[0], paths[1], streamInfo.URL, w)
// return writer, nil
// }
+700
View File
@@ -0,0 +1,700 @@
package flv
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"github.com/H0RlZ0N/gortmppush/av"
"github.com/H0RlZ0N/gortmppush/media/aac"
"github.com/H0RlZ0N/gortmppush/media/h264"
"github.com/H0RlZ0N/gortmppush/utils"
)
type flvTag struct {
fType uint8 //8bit tag类型,包括音频tag8),视频tag(9),脚本tag(18)
dataSize uint32 //24bit 数据长度,从streamID后面算起
timeStamp uint32 //24bit 时间戳,单位是毫秒,对于脚本类型tag,总是为0
timeStampExtend uint8 //8bit 时间戳扩展,将时间戳扩展为4bytes,代表时间戳高8位
streamID uint32 //24bit always 0
}
/*
flv 格式
Header|PreviousTagSize0|Tag1|PreviousTagSize1|Tag2|PreviousTagSize2|...|TagN|PreviousTagSizeN|
Header
* Signature(3Byte) 固定为f l v三个字符
* Version(3Byte) 一般为1
* Flags(1Byte) 第0位和第2位分别表示video和audio存在的情况(1存在,0不存在)
* DataOffet(4Byte) 表示flv文件header的长度
Body由一系列的Tag和size组成,tag分为videoaudio和script
tag格式 tag type|tag data size|timestamp|timestamp extended|stream id|tag data
script tag 存放flv的MetaData信息,比如durationaudiodataratecreatorwidth等信息
video tag
第一个byte是视频信息, 格式如下:
帧类型 (4bits) 取值:
1: keyframe (for AVC, a seekable frame)
2: inter frame (for AVC, a non-seekable frame)
3: disposable inter frame (H.263 only)
4: generated keyframe (reserved for server use only)
5: video info/command frame
编码ID (4 bits) 取值:
1: JPEG (currently unused)
2: Sorenson H.263
3: Screen video
4: On2 VP6
5: On2 VP6 with alpha channel
6: Screen video version 2
7: AVC
接下来就是具体的video的流数据的封装
对于AVC(h264)格式的video,除了第一个字节的帧类型和编码id以外,从第二个字节开始,分别为
AVC包类型:AVCPacketType (8Bits) 取值:
0: AVC sequence header (这个必须在发送第一帧前发送,包含sps,pps解码相关信息)
1: AVC NALU
2: AVC end of sequence
CompositionTime (24Bits) 取值:
如果上面的AVCPacketType=0x01, 为相对时间戳;
其它: 均为0;
Data (n Bytes) 为负载数据 取值:
如果AVCPacketType=0x00, 为AVCDecorderConfigurationRecord;
如果AVCPacketType=0x01, 为NALUs;
如果AVCPacketType=0x02, 为空.
AVCDecoderConfigurationRecord详细说明:
一般第一个视频Tag会封装视频编码的总体描述信息(AVC sequence header), 就是AVCDecoderConfigurationRecord结构(ISO/IEC 14496-15 AVC file format中规定). 其结构如下:
aligned(8) class AVCDecoderConfigurationRecord {
unsigned int(8) configurationVersion = 1;
unsigned int(8) AVCProfileIndication;
unsigned int(8) profile_compatibility;
unsigned int(8) AVCLevelIndication;
bit(6) reserved = 111111b;
unsigned int(2) lengthSizeMinusOne;
bit(3) reserved = 111b;
unsigned int(5) numOfSequenceParameterSets;
for (i=0; i< numOfSequenceParameterSets; i++) {
unsigned int(16) sequenceParameterSetLength ;
bit(8*sequenceParameterSetLength) sequenceParameterSetNALUnit;
}
unsigned int(8) numOfPictureParameterSets;
for (i=0; i< numOfPictureParameterSets; i++) {
unsigned int(16) pictureParameterSetLength;
bit(8*pictureParameterSetLength) pictureParameterSetNALUnit;
}
}
*/
/*
audio tag
SoundFormat 4bit
SoundRage 2bit
SoundSize 1bit
SoundType 1bit
SoundData n bytes //音频数据
当SoundFormat == 10 时,SoundData的数据时AAC格式
AACAudioData格式如下
AACPacketType 8bit 0--aac sequence header 1--aac raw
Data n bytes
当 AACPacketType == 0 时,Data数据为AudioSpecificConfig
当 AAXCPacketType == 1 时,Data数据为AAC raw frame data
AudioSpecificConfig 格式为
audioObjectType 5bit
samplingFrequencyIndex 4bit
if samplingFrequencyIndex == 15 {
当samplingFrequencyIndex samplingFrequency直接指定采用率,否则不占位
samplingFrequency 24bit
}
channelConfiguration 4bit //输出声道信息,如双声道为2
GASpecificConfig 包含以下三项)
frameLengthFlag 1bit
dependsOnCoreCoder 1bit
extensionFlag 1bit
*/
// meidaTag 包含视频tag和音频tag
type mediaTag struct {
/*
SoundFormat: UB[4]
0 = Linear PCM, platform endian
1 = ADPCM
2 = MP3
3 = Linear PCM, little endian
4 = Nellymoser 16-kHz mono
5 = Nellymoser 8-kHz mono
6 = Nellymoser
7 = G.711 A-law logarithmic PCM
8 = G.711 mu-law logarithmic PCM
9 = reserved
10 = AAC
11 = Speex
14 = MP3 8-Khz
15 = Device-specific sound
Formats 7, 8, 14, and 15 are reserved for internal use
AAC is supported in Flash Player 9,0,115,0 and higher.
Speex is supported in Flash Player 10 and higher.
*/
soundFormat uint8
/*
SoundRate: UB[2]
Sampling rate
0 = 5.5-kHz For AAC: always 3
1 = 11-kHz
2 = 22-kHz
3 = 44-kHz
*/
soundRate uint8
/*
SoundSize: UB[1]
0 = snd8Bit
1 = snd16Bit
Size of each sample.
This parameter only pertains to uncompressed formats.
Compressed formats always decode to 16 bits internally
*/
soundSize uint8
/*
SoundType: UB[1]
0 = sndMono
1 = sndStereo
Mono or stereo sound For Nellymoser: always 0
For AAC: always 1
*/
soundType uint8
/*
0: AAC sequence header
1: AAC raw
*/
aacPacketType uint8
/*
1: keyframe (for AVC, a seekable frame)
2: inter frame (for AVC, a non- seekable frame)
3: disposable inter frame (H.263 only)
4: generated keyframe (reserved for server use only)
5: video info/command frame
*/
frameType uint8
/*
1: JPEG (currently unused)
2: Sorenson H.263
3: Screen video
4: On2 VP6
5: On2 VP6 with alpha channel
6: Screen video version 2
7: AVC
*/
codecID uint8
/*
0: AVC sequence header
1: AVC NALU
2: AVC end of sequence (lower level NALU sequence ender is not required or supported)
*/
avcPacketType uint8
compositionTime int32
}
type Tag struct {
flvt flvTag
mediat mediaTag
}
// ParseAudioHeader ...
func (tag *Tag) ParseAudioHeader(b []byte) (n int, err error) {
if len(b) < n+1 {
err = fmt.Errorf("invalid audiodata len=%d", len(b))
return
}
flags := b[0]
tag.mediat.soundFormat = flags >> 4
tag.mediat.soundRate = (flags >> 2) & 0x3
tag.mediat.soundSize = (flags >> 1) & 0x1
tag.mediat.soundType = flags & 0x1
n++
switch tag.mediat.soundFormat {
case av.SOUND_AAC:
tag.mediat.aacPacketType = b[1]
n++
}
return
}
// ParseVideoHeader ...
func (tag *Tag) ParseVideoHeader(b []byte) (n int, err error) {
if len(b) < n+5 {
err = fmt.Errorf("invalid videodata len=%d", len(b))
return
}
//第一个字节包含帧类型(4bit)和编码id(4bit)
flags := b[0]
tag.mediat.frameType = flags >> 4 //获取帧类型
tag.mediat.codecID = flags & 0xf //获取编码id
n++
if tag.mediat.codecID == av.VIDEO_H264 {
//如果编码id是avc,再获取avc的视频封装格式
tag.mediat.avcPacketType = b[1] //AVCPacketType 0-sequence header 1-nalue 2-end of sequence
//获取3个字节的compositionTime
for i := 2; i < 5; i++ {
tag.mediat.compositionTime = tag.mediat.compositionTime<<8 + int32(b[i])
}
n += 4
}
// if tag.mediat.frameType == av.FRAME_INTER || tag.mediat.frameType == av.FRAME_KEY {
// tag.mediat.avcPacketType = b[1]
// for i := 2; i < 5; i++ {
// tag.mediat.compositionTime = tag.mediat.compositionTime<<8 + int32(b[i])
// }
// n += 4
// }
return
}
// NewAVCSequenceHeader comment
func NewAVCSequenceHeader(sps, pps []byte, timeStamp uint32) []byte {
avcConfigRecord := h264.AVCDecoderConfigurationRecord(sps, pps)
tag := &Tag{
flvt: flvTag{
fType: av.TAG_VIDEO, //uint8 //8bit tag类型,包括音频tag8),视频tag(9),脚本tag(18)
dataSize: uint32(len(avcConfigRecord)), //uint32 //24bit 数据长度,从streamID后面算起
timeStamp: timeStamp, //uint32 //24bit 时间戳,单位是毫秒,对于脚本类型tag,总是为0
timeStampExtend: 0, //8bit 时间戳扩展,将时间戳扩展为4bytes,代表时间戳高8位
streamID: 0, //24bit always 0
},
mediat: mediaTag{
frameType: av.FRAME_KEY,
codecID: av.VIDEO_H264,
avcPacketType: av.AVC_SEQHDR,
compositionTime: 0,
},
}
index := 0
tagBuffer := muxerTagData(tag)
buffer := make([]byte, len(tagBuffer)+4+len(avcConfigRecord))
copy(buffer, tagBuffer)
index += len(tagBuffer)
//utils.PutU32BE(buffer[index:], uint32(len(avcConfigRecord)))
//index += 4
copy(buffer[index:], avcConfigRecord)
index += len(avcConfigRecord)
return buffer[:index]
}
// ParseAVCSequenceHeader 解析sps和pps
func ParseAVCSequenceHeader(data []byte) (spss, ppss [][]byte, err error) {
reader := bytes.NewReader(data)
var rb byte
if rb, err = reader.ReadByte(); err != nil {
err = fmt.Errorf("read version failed, %v", err)
return
}
//校验version,应该为1
if rb != 0x01 {
err = errors.New("version should be 0x01")
return
}
//读取接下来三个字节, 分别为avcProfileIndication, profileCompatility, avcLevelIndication
var apa [3]byte
if _, err = reader.Read(apa[0:]); err != nil {
err = fmt.Errorf("read apa failed, %v", err)
return
}
//跳过一个字节
if _, err = reader.Seek(1, io.SeekCurrent); err != nil {
err = fmt.Errorf("reader.Seek failed, %v", err)
return
}
if rb, err = reader.ReadByte(); err != nil {
err = fmt.Errorf("read number of sps failed, %v", err)
return
}
numberOfsps := int(rb & 0x1f)
for i := 0; i < numberOfsps; i++ {
var lengthBytes [2]byte
if _, err = reader.Read(lengthBytes[0:]); err != nil {
err = fmt.Errorf("read sps length failed, %v", err)
return
}
length := binary.BigEndian.Uint16(lengthBytes[0:])
sps := make([]byte, length)
if _, err = reader.Read(sps[0:]); err != nil {
err = fmt.Errorf("read sps failed, %v", err)
return
}
spss = append(spss, sps)
}
//读取pps长度
if rb, err = reader.ReadByte(); err != nil {
err = fmt.Errorf("read number of sps failed, %v", err)
return
}
numberOfpps := int(rb)
for i := 0; i < numberOfpps; i++ {
var lengthBytes [2]byte
if _, err = reader.Read(lengthBytes[0:]); err != nil {
err = fmt.Errorf("read sps length failed, %v", err)
return
}
length := binary.BigEndian.Uint16(lengthBytes[0:])
pps := make([]byte, length)
if _, err = reader.Read(pps[0:]); err != nil {
err = fmt.Errorf("read sps failed, %v", err)
return
}
ppss = append(ppss, pps)
}
return
}
// NewAACSequenceHeader comment
func NewAACSequenceHeader(ah av.AudioPacketHeader) []byte {
var (
objectType uint8
samplingFrequenceIndex uint8
channelConfiguration uint8
)
objectType = 2 //AAC_LC
switch ah.SoundRate {
case av.SOUND_RATE_5_5Khz, av.SOUND_RATE_7Khz:
samplingFrequenceIndex = 4 //不支持5.5kHz, 7Khz
case av.SOUND_RATE_8Khz:
samplingFrequenceIndex = 11
case av.SOUND_RATE_11Khz:
samplingFrequenceIndex = 10
case av.SOUND_RATE_12Khz:
samplingFrequenceIndex = 9
case av.SOUND_RATE_16Khz:
samplingFrequenceIndex = 8
case av.SOUND_RATE_22Khz:
samplingFrequenceIndex = 7
case av.SOUND_RATE_24Khz:
samplingFrequenceIndex = 6
case av.SOUND_RATE_32Khz:
samplingFrequenceIndex = 5
case av.SOUND_RATE_44Khz:
samplingFrequenceIndex = 4
case av.SOUND_RATE_48Khz:
samplingFrequenceIndex = 3
case av.SOUND_RATE_64Khz:
samplingFrequenceIndex = 2
case av.SOUND_RATE_88Khz:
samplingFrequenceIndex = 1
case av.SOUND_RATE_96Khz:
samplingFrequenceIndex = 0
default:
samplingFrequenceIndex = 4
}
if ah.SoundType == av.SOUND_MONO {
channelConfiguration = 1
} else if ah.SoundType == av.SOUND_STEREO {
channelConfiguration = 2
} else {
channelConfiguration = 2
}
specificConfig := aac.SpecificConfig(objectType, samplingFrequenceIndex, channelConfiguration)
tag := &Tag{
flvt: flvTag{
fType: av.TAG_AUDIO,
// dataSize: uint32(len(specificConfig)),
// timeStamp: timeStamp,
// timeStampExtend: 0,
// streamID: 0,
},
mediat: mediaTag{
soundFormat: ah.SoundFormat, //aac
soundRate: ah.SoundRate, //44KHz
soundSize: ah.SoundSize,
soundType: ah.SoundType, //单声道
aacPacketType: av.AAC_SEQHDR,
},
}
index := 0
tagBuffer := muxerTagData(tag)
buffer := make([]byte, len(tagBuffer)+len(specificConfig))
copy(buffer, tagBuffer)
index += len(tagBuffer)
copy(buffer[index:], specificConfig)
index += len(specificConfig)
return buffer[:index]
}
// PackVideoData 打包音数据到buffer中,按照flv的video tag的格式打包
func PackVideoData(header *av.VideoPacketHeader, streamID uint32, src []byte,
timeStamp uint32) ([]byte, error) {
var tag *Tag
switch header.CodecID {
case av.VIDEO_H264:
if len(src) >= 4 && bytes.Compare(src[0:4], h264.StartCode4) == 0 {
src = src[4:]
}
if len(src) == 0 {
return nil, fmt.Errorf("invalid data")
}
//获取naluType类型
frameType := uint8(av.FRAME_INTER)
naluType := src[0] & 0x1F
if naluType == 7 || naluType == 8 || naluType == 5 {
frameType = uint8(av.FRAME_KEY)
}
tag = &Tag{
flvt: flvTag{
fType: av.TAG_VIDEO,
dataSize: uint32(len(src)), //在用rtmp协议发送是,改字段好像不起作用,正常情况是后面mediaTag+数据的长度
timeStamp: timeStamp,
timeStampExtend: 0, // todo
streamID: 0, // todo
},
mediat: mediaTag{
frameType: frameType,
codecID: header.CodecID,
avcPacketType: header.AVCPacketType,
compositionTime: 0, // todo
},
}
case av.VIDEO_JPEG, av.VideoH263:
tag = &Tag{
flvt: flvTag{
fType: av.TAG_VIDEO,
dataSize: uint32(len(src)),
timeStamp: timeStamp,
timeStampExtend: 0, // todo 这个需要处理一下
streamID: 0, //todo 这个需要设置
},
mediat: mediaTag{
frameType: av.FRAME_KEY, //jpeg都认为是key frame
codecID: header.CodecID,
avcPacketType: 0, // jpeg,这个字段不起作用
compositionTime: 0, // jpeg,这个字段不起作用
},
}
default:
return nil, fmt.Errorf("unsupport code id:%d", header.CodecID)
}
index := 0
//生成tagHeader 部分
tagBuffer := muxerTagData(tag)
if header.CodecID == av.VIDEO_H264 {
naluType := src[0] & 0x1F
if naluType == 7 || naluType == 8 {
//如果是7或8,应该是I帧,把里面的nalu单元都提取出来,打包
//naluLength nalu naluLength nalu
nalus := h264.ParseNalusN(src, 3)
buffer := make([]byte, len(tagBuffer)+len(src)+len(nalus)*4)
copy(buffer[index:], tagBuffer)
index += len(tagBuffer)
for _, nalu := range nalus {
if len := len(nalu); len > 0 {
utils.PutU32BE(buffer[index:], uint32(len))
index += 4
copy(buffer[index:], nalu)
index += len
}
}
return buffer[:index], nil
}
}
if len(src) == 0 {
return nil, fmt.Errorf("invalid data")
}
//否则都统一打包发送
//创建buffer len(tagBuffer) + 4字节长度 + 数据长度
buffer := make([]byte, len(tagBuffer)+4+len(src))
//拷贝tag 头数据
copy(buffer[index:], tagBuffer)
index += len(tagBuffer)
//设置数据长度
utils.PutU32BE(buffer[index:], uint32(len(src)))
index += 4
//拷贝数据
copy(buffer[index:], src)
index += len(src)
return buffer[:index], nil
}
// //NewAVCNaluData 把nalu单元打包成rtmp的payload
// func NewAVCNaluData(src []byte, timeStamp uint32) (buffer []byte) {
// //nalu单元至少要大于4个字节,包括start code(一帧开始的起始码应该是4位)
// if len(src) <= 4 {
// buffer = make([]byte, 0)
// return
// }
// //获取naluType类型
// frameType := uint8(av.FRAME_INTER)
// naluType := src[4] & 0x1F
// if naluType == 7 || naluType == 8 || naluType == 5 {
// frameType = uint8(av.FRAME_KEY)
// }
// tag := &Tag{
// flvt: flvTag{
// fType: av.TAG_VIDEO,
// dataSize: uint32(len(src)), //在用rtmp协议发送是,改字段好像不起作用,正常情况是后面mediaTag+数据的长度
// timeStamp: timeStamp,
// timeStampExtend: 0,
// streamID: 0,
// },
// mediat: mediaTag{
// frameType: frameType,
// codecID: av.VIDEO_H264,
// avcPacketType: av.AVC_NALU,
// compositionTime: 0,
// },
// }
// //生成tagHeader 部分
// tagBuffer := muxerTagData(tag)
// index := 0
// if naluType == 7 || naluType == 8 {
// //如果是7或8,应该是I帧,把里面的nalu单元都提取出来,打包
// //naluLength nalu naluLength nalu
// nalus := h264.ParseNalusN(src, 3)
// buffer = make([]byte, len(tagBuffer)+len(src)+len(nalus)*4)
// copy(buffer[index:], tagBuffer)
// index += len(tagBuffer)
// for _, nalu := range nalus {
// if len := len(nalu); len > 0 {
// utils.PutU32BE(buffer[index:], uint32(len))
// index += 4
// copy(buffer[index:], nalu)
// index += len
// }
// }
// return buffer[:index]
// }
// src = src[4:] //去掉start code
// //创建buffer len(tagBuffer) + 4字节长度 + 数据长度
// buffer = make([]byte, len(tagBuffer)+4+len(src))
// //拷贝tag 头数据
// copy(buffer[index:], tagBuffer)
// index += len(tagBuffer)
// //设置数据长度
// utils.PutU32BE(buffer[index:], uint32(len(src)))
// index += 4
// //拷贝数据
// copy(buffer[index:], src)
// index += len(src)
// return buffer[:index]
// }
// PackAudioData 打包音频数据
func PackAudioData(ah *av.AudioPacketHeader, streamID uint32, src []byte,
timeStamp uint32) ([]byte, error) {
if ah.SoundFormat != av.SOUND_AAC {
return nil, fmt.Errorf("code %d not support", ah.SoundFormat)
}
tag := &Tag{
flvt: flvTag{
fType: av.TAG_AUDIO,
dataSize: uint32(len(src)), //可能由上层协议作为一帧的分割,该字段没有效果
timeStamp: timeStamp,
timeStampExtend: 0, //todo
streamID: streamID,
},
mediat: mediaTag{
soundFormat: ah.SoundFormat,
soundRate: ah.SoundRate,
soundSize: ah.SoundSize,
soundType: ah.SoundType,
aacPacketType: av.AAC_RAW,
},
}
tagBuffer := muxerTagData(tag)
index := 0
buffer := make([]byte, len(tagBuffer)+len(src))
copy(buffer[index:], tagBuffer)
index += len(tagBuffer)
copy(buffer[index:], src)
index += len(src)
return buffer[:index], nil
}
// NewAACData comment
func NewAACData(ah av.AudioPacketHeader, src []byte, timeStamp uint32) (buffer []byte) {
tag := &Tag{
flvt: flvTag{
fType: av.TAG_AUDIO,
dataSize: uint32(len(src)), //可能由上层协议作为一帧的分割,该字段没有效果
timeStamp: timeStamp,
timeStampExtend: 0,
streamID: 0,
},
mediat: mediaTag{
soundFormat: ah.SoundFormat,
soundRate: ah.SoundRate,
soundSize: ah.SoundSize,
soundType: ah.SoundType,
aacPacketType: av.AAC_RAW,
},
}
tagBuffer := muxerTagData(tag)
index := 0
buffer = make([]byte, len(tagBuffer)+len(src))
copy(buffer[index:], tagBuffer)
index += len(tagBuffer)
copy(buffer[index:], src)
index += len(src)
return buffer[:index]
}
// MuxerTagData 打包tag头和数据部分,在用rtmp协议发送时,tag头只包含了mediaTag,没有flvTag数据
// 应该时flvTag这部分功能被chunk的功能替代了,不用flvTag也可以知道一个完整的帧,如果打包成flv文件时,
// flvTag不能省略
func muxerTagData(tag *Tag) []byte {
n := 0
buffer := make([]byte, 5) //16是按最大的长度来计算,aac有16个字节长度头
if tag.flvt.fType == av.TAG_VIDEO {
buffer[n] = (tag.mediat.frameType << 4) | (tag.mediat.codecID & 0x0F) //帧类型 4bit 编码id 4bit
n++
if tag.mediat.codecID == av.VIDEO_H264 {
//如果是h264,有额外的封装
utils.PutU8(buffer[n:], tag.mediat.avcPacketType) //AVCPacketType 8bit
n++
utils.PutU24BE(buffer[n:], uint32(tag.mediat.compositionTime)) //CompositionTime 24bit
n += 3
}
} else if tag.flvt.fType == av.TAG_AUDIO {
//音频格式 4bit
//采样率 2bit
//采样长度 1bit
//音频类型 1bit
buffer[n] = (tag.mediat.soundFormat << 4) | (tag.mediat.soundRate << 2 & 0x0C) |
(tag.mediat.soundSize << 1 & 0x02) | (tag.mediat.soundType & 0x01)
n++
if tag.mediat.soundFormat == av.SOUND_AAC {
utils.PutU8(buffer[n:], tag.mediat.aacPacketType) //AACPacketType
n++
}
}
//剩下的都认为是脚本tag,直接写入数据
return buffer[:n]
}
+1
View File
@@ -0,0 +1 @@
package flv
+78
View File
@@ -0,0 +1,78 @@
package ts
func GenCrc32(src []byte) uint32 {
crcTable := []uint32{
0x00000000, 0x04c11db7, 0x09823b6e, 0x0d4326d9,
0x130476dc, 0x17c56b6b, 0x1a864db2, 0x1e475005,
0x2608edb8, 0x22c9f00f, 0x2f8ad6d6, 0x2b4bcb61,
0x350c9b64, 0x31cd86d3, 0x3c8ea00a, 0x384fbdbd,
0x4c11db70, 0x48d0c6c7, 0x4593e01e, 0x4152fda9,
0x5f15adac, 0x5bd4b01b, 0x569796c2, 0x52568b75,
0x6a1936c8, 0x6ed82b7f, 0x639b0da6, 0x675a1011,
0x791d4014, 0x7ddc5da3, 0x709f7b7a, 0x745e66cd,
0x9823b6e0, 0x9ce2ab57, 0x91a18d8e, 0x95609039,
0x8b27c03c, 0x8fe6dd8b, 0x82a5fb52, 0x8664e6e5,
0xbe2b5b58, 0xbaea46ef, 0xb7a96036, 0xb3687d81,
0xad2f2d84, 0xa9ee3033, 0xa4ad16ea, 0xa06c0b5d,
0xd4326d90, 0xd0f37027, 0xddb056fe, 0xd9714b49,
0xc7361b4c, 0xc3f706fb, 0xceb42022, 0xca753d95,
0xf23a8028, 0xf6fb9d9f, 0xfbb8bb46, 0xff79a6f1,
0xe13ef6f4, 0xe5ffeb43, 0xe8bccd9a, 0xec7dd02d,
0x34867077, 0x30476dc0, 0x3d044b19, 0x39c556ae,
0x278206ab, 0x23431b1c, 0x2e003dc5, 0x2ac12072,
0x128e9dcf, 0x164f8078, 0x1b0ca6a1, 0x1fcdbb16,
0x018aeb13, 0x054bf6a4, 0x0808d07d, 0x0cc9cdca,
0x7897ab07, 0x7c56b6b0, 0x71159069, 0x75d48dde,
0x6b93dddb, 0x6f52c06c, 0x6211e6b5, 0x66d0fb02,
0x5e9f46bf, 0x5a5e5b08, 0x571d7dd1, 0x53dc6066,
0x4d9b3063, 0x495a2dd4, 0x44190b0d, 0x40d816ba,
0xaca5c697, 0xa864db20, 0xa527fdf9, 0xa1e6e04e,
0xbfa1b04b, 0xbb60adfc, 0xb6238b25, 0xb2e29692,
0x8aad2b2f, 0x8e6c3698, 0x832f1041, 0x87ee0df6,
0x99a95df3, 0x9d684044, 0x902b669d, 0x94ea7b2a,
0xe0b41de7, 0xe4750050, 0xe9362689, 0xedf73b3e,
0xf3b06b3b, 0xf771768c, 0xfa325055, 0xfef34de2,
0xc6bcf05f, 0xc27dede8, 0xcf3ecb31, 0xcbffd686,
0xd5b88683, 0xd1799b34, 0xdc3abded, 0xd8fba05a,
0x690ce0ee, 0x6dcdfd59, 0x608edb80, 0x644fc637,
0x7a089632, 0x7ec98b85, 0x738aad5c, 0x774bb0eb,
0x4f040d56, 0x4bc510e1, 0x46863638, 0x42472b8f,
0x5c007b8a, 0x58c1663d, 0x558240e4, 0x51435d53,
0x251d3b9e, 0x21dc2629, 0x2c9f00f0, 0x285e1d47,
0x36194d42, 0x32d850f5, 0x3f9b762c, 0x3b5a6b9b,
0x0315d626, 0x07d4cb91, 0x0a97ed48, 0x0e56f0ff,
0x1011a0fa, 0x14d0bd4d, 0x19939b94, 0x1d528623,
0xf12f560e, 0xf5ee4bb9, 0xf8ad6d60, 0xfc6c70d7,
0xe22b20d2, 0xe6ea3d65, 0xeba91bbc, 0xef68060b,
0xd727bbb6, 0xd3e6a601, 0xdea580d8, 0xda649d6f,
0xc423cd6a, 0xc0e2d0dd, 0xcda1f604, 0xc960ebb3,
0xbd3e8d7e, 0xb9ff90c9, 0xb4bcb610, 0xb07daba7,
0xae3afba2, 0xaafbe615, 0xa7b8c0cc, 0xa379dd7b,
0x9b3660c6, 0x9ff77d71, 0x92b45ba8, 0x9675461f,
0x8832161a, 0x8cf30bad, 0x81b02d74, 0x857130c3,
0x5d8a9099, 0x594b8d2e, 0x5408abf7, 0x50c9b640,
0x4e8ee645, 0x4a4ffbf2, 0x470cdd2b, 0x43cdc09c,
0x7b827d21, 0x7f436096, 0x7200464f, 0x76c15bf8,
0x68860bfd, 0x6c47164a, 0x61043093, 0x65c52d24,
0x119b4be9, 0x155a565e, 0x18197087, 0x1cd86d30,
0x029f3d35, 0x065e2082, 0x0b1d065b, 0x0fdc1bec,
0x3793a651, 0x3352bbe6, 0x3e119d3f, 0x3ad08088,
0x2497d08d, 0x2056cd3a, 0x2d15ebe3, 0x29d4f654,
0xc5a92679, 0xc1683bce, 0xcc2b1d17, 0xc8ea00a0,
0xd6ad50a5, 0xd26c4d12, 0xdf2f6bcb, 0xdbee767c,
0xe3a1cbc1, 0xe760d676, 0xea23f0af, 0xeee2ed18,
0xf0a5bd1d, 0xf464a0aa, 0xf9278673, 0xfde69bc4,
0x89b8fd09, 0x8d79e0be, 0x803ac667, 0x84fbdbd0,
0x9abc8bd5, 0x9e7d9662, 0x933eb0bb, 0x97ffad0c,
0xafb010b1, 0xab710d06, 0xa6322bdf, 0xa2f33668,
0xbcb4666d, 0xb8757bda, 0xb5365d03, 0xb1f740b4}
j := byte(0)
crc32 := uint32(0xFFFFFFFF)
for i := 0; i < len(src); i++ {
j = (byte(crc32>>24) ^ src[i]) & 0xff
crc32 = uint32(uint32(crc32<<8) ^ uint32(crcTable[j]))
}
return crc32
}
+364
View File
@@ -0,0 +1,364 @@
package ts
import (
"io"
"github.com/H0RlZ0N/gortmppush/av"
)
const (
tsDefaultDataLen = 184
tsPacketLen = 188
h264DefaultHZ = 90
videoPID = 0x100
audioPID = 0x101
videoSID = 0xe0
audioSID = 0xc0
)
type Muxer struct {
videoCc byte
audioCc byte
patCc byte
pmtCc byte
pat [tsPacketLen]byte
pmt [tsPacketLen]byte
tsPacket [tsPacketLen]byte
}
func NewMuxer() *Muxer {
return &Muxer{}
}
func (muxer *Muxer) Mux(p *av.Packet, w io.Writer) error {
first := true
wBytes := 0
pesIndex := 0
tmpLen := byte(0)
dataLen := byte(0)
var pes pesHeader
dts := int64(p.TimeStamp) * int64(h264DefaultHZ)
pts := dts
pid := audioPID
var videoH av.VideoPacketHeader
if p.PacketType == av.PacketTypeVideo {
pid = videoPID
pts = dts + int64(p.VHeader.CompositionTime)*int64(h264DefaultHZ)
}
err := pes.packet(p, pts, dts)
if err != nil {
return err
}
pesHeaderLen := pes.len
packetBytesLen := len(p.Data) + int(pesHeaderLen)
for {
if packetBytesLen <= 0 {
break
}
if p.PacketType == av.PacketTypeVideo {
muxer.videoCc++
if muxer.videoCc > 0xf {
muxer.videoCc = 0
}
} else {
muxer.audioCc++
if muxer.audioCc > 0xf {
muxer.audioCc = 0
}
}
i := byte(0)
//sync byte
muxer.tsPacket[i] = 0x47
i++
//error indicator, unit start indicator,ts priority,pid
muxer.tsPacket[i] = byte(pid >> 8) //pid high 5 bits
if first {
muxer.tsPacket[i] = muxer.tsPacket[i] | 0x40 //unit start indicator
}
i++
//pid low 8 bits
muxer.tsPacket[i] = byte(pid)
i++
//scram control, adaptation control, counter
if p.PacketType == av.PacketTypeVideo {
muxer.tsPacket[i] = 0x10 | byte(muxer.videoCc&0x0f)
} else {
muxer.tsPacket[i] = 0x10 | byte(muxer.audioCc&0x0f)
}
i++
//关键帧需要加pcr
if first && p.PacketType == av.PacketTypeVideo &&
videoH.FrameType == av.FRAME_KEY {
muxer.tsPacket[3] |= 0x20
muxer.tsPacket[i] = 7
i++
muxer.tsPacket[i] = 0x50
i++
muxer.writePcr(muxer.tsPacket[0:], i, dts)
i += 6
}
//frame data
if packetBytesLen >= tsDefaultDataLen {
dataLen = tsDefaultDataLen
if first {
dataLen -= (i - 4)
}
} else {
muxer.tsPacket[3] |= 0x20 //have adaptation
remainBytes := byte(0)
dataLen = byte(packetBytesLen)
if first {
remainBytes = tsDefaultDataLen - dataLen - (i - 4)
} else {
remainBytes = tsDefaultDataLen - dataLen
}
muxer.adaptationBufInit(muxer.tsPacket[i:], byte(remainBytes))
i += remainBytes
}
if first && i < tsPacketLen && pesHeaderLen > 0 {
tmpLen = tsPacketLen - i
if pesHeaderLen <= tmpLen {
tmpLen = pesHeaderLen
}
copy(muxer.tsPacket[i:], pes.data[pesIndex:pesIndex+int(tmpLen)])
i += tmpLen
packetBytesLen -= int(tmpLen)
dataLen -= tmpLen
pesHeaderLen -= tmpLen
pesIndex += int(tmpLen)
}
if i < tsPacketLen {
tmpLen = tsPacketLen - i
if tmpLen <= dataLen {
dataLen = tmpLen
}
copy(muxer.tsPacket[i:], p.Data[wBytes:wBytes+int(dataLen)])
wBytes += int(dataLen)
packetBytesLen -= int(dataLen)
}
if w != nil {
if _, err := w.Write(muxer.tsPacket[0:]); err != nil {
return err
}
}
first = false
}
return nil
}
// PAT return pat data
func (muxer *Muxer) PAT() []byte {
i := 0
remainByte := 0
tsHeader := []byte{0x47, 0x40, 0x00, 0x10, 0x00}
patHeader := []byte{0x00, 0xb0, 0x0d, 0x00, 0x01, 0xc1, 0x00, 0x00, 0x00, 0x01, 0xf0, 0x01}
if muxer.patCc > 0xf {
muxer.patCc = 0
}
tsHeader[3] |= muxer.patCc & 0x0f
muxer.patCc++
copy(muxer.pat[i:], tsHeader)
i += len(tsHeader)
copy(muxer.pat[i:], patHeader)
i += len(patHeader)
crc32Value := GenCrc32(patHeader)
muxer.pat[i] = byte(crc32Value >> 24)
i++
muxer.pat[i] = byte(crc32Value >> 16)
i++
muxer.pat[i] = byte(crc32Value >> 8)
i++
muxer.pat[i] = byte(crc32Value)
i++
remainByte = int(tsPacketLen - i)
for j := 0; j < remainByte; j++ {
muxer.pat[i+j] = 0xff
}
return muxer.pat[0:]
}
// PMT return pmt data
func (muxer *Muxer) PMT(soundFormat byte, hasVideo bool) []byte {
i := int(0)
j := int(0)
var progInfo []byte
remainBytes := int(0)
tsHeader := []byte{0x47, 0x50, 0x01, 0x10, 0x00}
pmtHeader := []byte{0x02, 0xb0, 0xff, 0x00, 0x01, 0xc1, 0x00, 0x00, 0xe1, 0x00, 0xf0, 0x00}
if !hasVideo {
pmtHeader[9] = 0x01
progInfo = []byte{0x0f, 0xe1, 0x01, 0xf0, 0x00}
} else {
progInfo = []byte{0x1b, 0xe1, 0x00, 0xf0, 0x00, //h264 or h265*
0x0f, 0xe1, 0x01, 0xf0, 0x00, //mp3 or aac
}
}
pmtHeader[2] = byte(len(progInfo) + 9 + 4)
if muxer.pmtCc > 0xf {
muxer.pmtCc = 0
}
tsHeader[3] |= muxer.pmtCc & 0x0f
muxer.pmtCc++
if soundFormat == 2 ||
soundFormat == 14 {
if hasVideo {
progInfo[5] = 0x4
} else {
progInfo[0] = 0x4
}
}
copy(muxer.pmt[i:], tsHeader)
i += len(tsHeader)
copy(muxer.pmt[i:], pmtHeader)
i += len(pmtHeader)
copy(muxer.pmt[i:], progInfo[0:])
i += len(progInfo)
crc32Value := GenCrc32(muxer.pmt[5 : 5+len(pmtHeader)+len(progInfo)])
muxer.pmt[i] = byte(crc32Value >> 24)
i++
muxer.pmt[i] = byte(crc32Value >> 16)
i++
muxer.pmt[i] = byte(crc32Value >> 8)
i++
muxer.pmt[i] = byte(crc32Value)
i++
remainBytes = int(tsPacketLen - i)
for j = 0; j < remainBytes; j++ {
muxer.pmt[i+j] = 0xff
}
return muxer.pmt[0:]
}
func (muxer *Muxer) adaptationBufInit(src []byte, remainBytes byte) {
src[0] = byte(remainBytes - 1)
if remainBytes == 1 {
} else {
src[1] = 0x00
for i := 2; i < len(src); i++ {
src[i] = 0xff
}
}
return
}
func (muxer *Muxer) writePcr(b []byte, i byte, pcr int64) error {
b[i] = byte(pcr >> 25)
i++
b[i] = byte((pcr >> 17) & 0xff)
i++
b[i] = byte((pcr >> 9) & 0xff)
i++
b[i] = byte((pcr >> 1) & 0xff)
i++
b[i] = byte(((pcr & 0x1) << 7) | 0x7e)
i++
b[i] = 0x00
return nil
}
type pesHeader struct {
len byte
data [tsPacketLen]byte
}
// pesPacket return pes packet
func (header *pesHeader) packet(p *av.Packet, pts, dts int64) error {
//PES header
i := 0
header.data[i] = 0x00
i++
header.data[i] = 0x00
i++
header.data[i] = 0x01
i++
sid := audioSID
if p.PacketType == av.PacketTypeVideo {
sid = videoSID
}
header.data[i] = byte(sid)
i++
flag := 0x80
ptslen := 5
dtslen := ptslen
headerSize := ptslen
if p.PacketType == av.PacketTypeVideo && pts != dts {
flag |= 0x40
headerSize += 5 //add dts
}
size := len(p.Data) + headerSize + 3
if size > 0xffff {
size = 0
}
header.data[i] = byte(size >> 8)
i++
header.data[i] = byte(size)
i++
header.data[i] = 0x80
i++
header.data[i] = byte(flag)
i++
header.data[i] = byte(headerSize)
i++
header.writeTs(header.data[0:], i, flag>>6, pts)
i += ptslen
if p.PacketType == av.PacketTypeVideo && pts != dts {
header.writeTs(header.data[0:], i, 1, dts)
i += dtslen
}
header.len = byte(i)
return nil
}
func (header *pesHeader) writeTs(src []byte, i int, fb int, ts int64) {
val := uint32(0)
if ts > 0x1ffffffff {
ts -= 0x1ffffffff
}
val = uint32(fb<<4) | ((uint32(ts>>30) & 0x07) << 1) | 1
src[i] = byte(val)
i++
val = ((uint32(ts>>15) & 0x7fff) << 1) | 1
src[i] = byte(val >> 8)
i++
src[i] = byte(val)
i++
val = (uint32(ts&0x7fff) << 1) | 1
src[i] = byte(val >> 8)
i++
src[i] = byte(val)
}
+52
View File
@@ -0,0 +1,52 @@
package ts
import (
"testing"
"github.com/H0RlZ0N/gortmppush/av"
"github.com/stretchr/testify/assert"
)
type TestWriter struct {
buf []byte
count int
}
// Write write p to w.buf
func (w *TestWriter) Write(p []byte) (int, error) {
w.count++
w.buf = p
return len(p), nil
}
func TestTSEncoder(t *testing.T) {
at := assert.New(t)
m := NewMuxer()
w := &TestWriter{}
data := []byte{0xaf, 0x01, 0x21, 0x19, 0xd3, 0x40, 0x7d, 0x0b, 0x6d, 0x44, 0xae, 0x81,
0x08, 0x00, 0x89, 0xa0, 0x3e, 0x85, 0xb6, 0x92, 0x57, 0x04, 0x80, 0x00, 0x5b, 0xb7,
0x78, 0x00, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00, 0x38, 0x30, 0x00, 0x06, 0x00, 0x38,
}
p := av.Packet{
PacketType: av.PacketTypeUnknow,
Data: data,
}
err := m.Mux(&p, w)
at.Equal(err, nil)
at.Equal(w.count, 1)
at.Equal(w.buf, []byte{0x47, 0x41, 0x01, 0x31, 0x81, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x01, 0xc0, 0x00, 0x30,
0x80, 0x80, 0x05, 0x21, 0x00, 0x01, 0x00, 0x01, 0xaf, 0x01, 0x21, 0x19, 0xd3, 0x40, 0x7d,
0x0b, 0x6d, 0x44, 0xae, 0x81, 0x08, 0x00, 0x89, 0xa0, 0x3e, 0x85, 0xb6, 0x92, 0x57, 0x04,
0x80, 0x00, 0x5b, 0xb7, 0x78, 0x00, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00, 0x38, 0x30, 0x00,
0x06, 0x00, 0x38})
}
+9
View File
@@ -0,0 +1,9 @@
all:
gofmt -d -l -e -w .
go mod tidy
go build
run:
demo.exe
BIN
View File
Binary file not shown.
+119
View File
@@ -0,0 +1,119 @@
package main
import (
"bytes"
"flag"
"fmt"
"io/ioutil"
"time"
"github.com/H0RlZ0N/gortmppush"
"github.com/H0RlZ0N/gortmppush/av"
)
var startCode []byte = []byte{0x00, 0x00, 0x00, 0x01}
func PushH264(client *gortmppush.RtmpClient) {
data, err := ioutil.ReadFile("test.h264")
if err != nil {
panic(err)
}
bcheck := true
timeStamp := 0
pre := 0
index := 0
for {
if bytes.Compare(data[index:index+4], startCode) == 0 {
if index > pre {
if bcheck {
preNaluType := data[pre+4] & 0x1F
naluType := data[index+4] & 0x1F
if preNaluType == 7 {
if naluType != 7 && naluType != 8 {
bcheck = false
}
index += 4
continue
}
}
bcheck = true
pkt := &av.Packet{
PacketType: av.PacketTypeVideo,
TimeStamp: uint32(timeStamp),
Data: make([]byte, index-pre),
StreamID: 0,
VHeader: av.VideoPacketHeader{
FrameType: av.FRAME_KEY,
CodecID: av.VIDEO_H264,
AVCPacketType: av.AVC_NALU,
CompositionTime: 0,
},
}
copy(pkt.Data, data[pre:index])
if err := client.SendPacket(pkt); err != nil {
panic(err)
}
time.Sleep(time.Millisecond * 100)
timeStamp += 100
pre = index
index += 4
continue
}
}
index++
if index+4 >= len(data) {
break
}
}
}
func PushJPEG(client *gortmppush.RtmpClient) {
index := 0
timeStamp := 0
for {
fileName := fmt.Sprintf("/Users/fabojiang/Documents/raw_jpeg_to_tuya/gc0308_img_%d.data", index)
data, err := ioutil.ReadFile(fileName)
if err != nil {
panic(err)
}
index++
if index >= 100 {
index = 0
}
pkt := &av.Packet{
PacketType: av.PacketTypeVideo,
TimeStamp: uint32(timeStamp),
Data: make([]byte, len(data)),
StreamID: 0,
VHeader: av.VideoPacketHeader{
FrameType: av.FRAME_KEY,
CodecID: av.VIDEO_JPEG,
CompositionTime: 0,
},
}
copy(pkt.Data, data)
if err := client.SendPacket(pkt); err != nil {
panic(err)
}
fmt.Println("Debug.... send jpeg.... ", len(data))
time.Sleep(time.Millisecond * 40)
timeStamp += 40
}
}
func main() {
flag.Parse()
api := gortmppush.NewAPI()
client := api.NewRtmpClient()
rtmpURL := fmt.Sprintf("rtmp://127.0.0.1:1935/live/one")
if err := client.OpenPublish(rtmpURL); err != nil {
panic(err)
}
PushH264(client)
}
BIN
View File
Binary file not shown.
BIN
View File
Binary file not shown.
+8
View File
@@ -0,0 +1,8 @@
module github.com/H0RlZ0N/gortmppush
go 1.13
require (
github.com/satori/go.uuid v1.2.0
github.com/stretchr/testify v1.4.0
)
+88
View File
@@ -0,0 +1,88 @@
package logger
import (
"fmt"
"runtime"
"strings"
"time"
)
// DefaultFactory ...
type DefaultFactory struct {
}
type defaultLogger struct {
level LogLevel
}
// NewDefaultFactory ...
func NewDefaultFactory() *DefaultFactory {
return &DefaultFactory{}
}
// NewLogger ...
func (f *DefaultFactory) NewLogger(scope LogLevel) Logger {
return defaultLogger{
level: scope,
}
}
func (log defaultLogger) Trace(msg string) {
log.output(2, LogLevelTrace, msg)
}
func (log defaultLogger) Tracef(format string, args ...interface{}) {
log.output(2, LogLevelTrace, fmt.Sprintf(format, args...))
}
func (log defaultLogger) Debug(msg string) {
log.output(2, LogLevelDebug, msg)
}
func (log defaultLogger) Debugf(format string, args ...interface{}) {
log.output(2, LogLevelDebug, fmt.Sprintf(format, args...))
}
func (log defaultLogger) Info(msg string) {
log.output(2, LogLevelInfo, msg)
}
func (log defaultLogger) Infof(format string, args ...interface{}) {
log.output(2, LogLevelInfo, fmt.Sprintf(format, args...))
}
func (log defaultLogger) Warn(msg string) {
log.output(2, LogLevelWarn, msg)
}
func (log defaultLogger) Warnf(format string, args ...interface{}) {
log.output(2, LogLevelWarn, fmt.Sprintf(format, args...))
}
func (log defaultLogger) Error(msg string) {
log.output(2, LogLevelError, msg)
}
func (log defaultLogger) Errorf(format string, args ...interface{}) {
log.output(2, LogLevelError, fmt.Sprintf(format, args...))
}
func (log defaultLogger) output(callDepth int, level LogLevel, s string) {
if log.level < level {
return
}
var (
file string
line int
ok bool
)
_, file, line, ok = runtime.Caller(callDepth)
if !ok {
file = "???"
line = 0
}
index := strings.LastIndex(file, "/")
fmt.Printf("%s %s %s:%d %s\n", time.Now().Format("2006-01-02 15:04:05.000"), level.String(), file[index+1:], line, s)
}
+70
View File
@@ -0,0 +1,70 @@
package logger
import "sync/atomic"
// LogLevel represents the level at witch the logger will emit log messages
type LogLevel int32
// Set updates the LogLevel to the supplied value
func (ll *LogLevel) Set(level LogLevel) {
atomic.StoreInt32((*int32)(ll), int32(level))
}
// Get retrives the current LogLevel value
func (ll *LogLevel) Get() LogLevel {
return LogLevel(atomic.LoadInt32((*int32)(ll)))
}
func (ll LogLevel) String() string {
switch ll {
case LogLevelDisabled:
return "Disabled"
case LogLevelError:
return "Error"
case LogLevelWarn:
return "Warn"
case LogLevelInfo:
return "Info"
case LogLevelDebug:
return "Debug"
case LogLevelTrace:
return "Trace"
default:
return "Unknow"
}
}
const (
// LogLevelDisabled completely disables logging of any events
LogLevelDisabled LogLevel = iota
// LogLevelError is for fatal errors which should be handled by user code,
// but are logged to ensure that they are seen
LogLevelError
// LogLevelWarn is for logging abnormal, but non-fatal library operation
LogLevelWarn
// LogLevelInfo is for logging normal library operation (e.g. state transitions, etc.)
LogLevelInfo
// LogLevelDebug is for logging low-level library information (e.g. internal operations)
LogLevelDebug
// LogLevelTrace is for logging very low-level library information (e.g. network traces)
LogLevelTrace
)
// Logger interface
type Logger interface {
Trace(msg string)
Tracef(format string, args ...interface{})
Debug(msg string)
Debugf(format string, args ...interface{})
Info(msg string)
Infof(format string, args ...interface{})
Warn(msg string)
Warnf(format string, args ...interface{})
Error(msg string)
Errorf(format string, args ...interface{})
}
// LoggerFactory interface
type LoggerFactory interface {
NewLogger(scope LogLevel) Logger
}
+126
View File
@@ -0,0 +1,126 @@
package aac
import (
"errors"
"io"
"github.com/H0RlZ0N/gortmppush/av"
)
type mpegExtension struct {
objectType byte
sampleRate byte
}
type mpegCfgInfo struct {
objectType byte
sampleRate byte
channel byte
sbr byte
ps byte
frameLen byte
exceptionLogTs int64
extension *mpegExtension
}
var aacRates = []int{96000, 88200, 64000, 48000, 44100, 32000, 24000, 22050, 16000, 12000, 11025, 8000, 7350}
var (
specificBufInvalid = errors.New("audio mpegspecific error")
audioBufInvalid = errors.New("audiodata invalid")
)
const (
adtsHeaderLen = 7
)
type Parser struct {
gettedSpecific bool
adtsHeader []byte
cfgInfo *mpegCfgInfo
}
func NewParser() *Parser {
return &Parser{
gettedSpecific: false,
cfgInfo: &mpegCfgInfo{},
adtsHeader: make([]byte, adtsHeaderLen),
}
}
func (parser *Parser) specificInfo(src []byte) error {
if len(src) < 2 {
return specificBufInvalid
}
parser.gettedSpecific = true
parser.cfgInfo.objectType = (src[0] >> 3) & 0xff
parser.cfgInfo.sampleRate = ((src[0] & 0x07) << 1) | src[1]>>7
parser.cfgInfo.channel = (src[1] >> 3) & 0x0f
return nil
}
func (parser *Parser) adts(src []byte, w io.Writer) error {
if len(src) <= 0 || !parser.gettedSpecific {
return audioBufInvalid
}
frameLen := uint16(len(src)) + 7
//first write adts header
parser.adtsHeader[0] = 0xff
parser.adtsHeader[1] = 0xf1
parser.adtsHeader[2] &= 0x00
parser.adtsHeader[2] = parser.adtsHeader[2] | (parser.cfgInfo.objectType-1)<<6
parser.adtsHeader[2] = parser.adtsHeader[2] | (parser.cfgInfo.sampleRate)<<2
parser.adtsHeader[3] &= 0x00
parser.adtsHeader[3] = parser.adtsHeader[3] | (parser.cfgInfo.channel<<2)<<4
parser.adtsHeader[3] = parser.adtsHeader[3] | byte((frameLen<<3)>>14)
parser.adtsHeader[4] &= 0x00
parser.adtsHeader[4] = parser.adtsHeader[4] | byte((frameLen<<5)>>8)
parser.adtsHeader[5] &= 0x00
parser.adtsHeader[5] = parser.adtsHeader[5] | byte(((frameLen<<13)>>13)<<5)
parser.adtsHeader[5] = parser.adtsHeader[5] | (0x7C<<1)>>3
parser.adtsHeader[6] = 0xfc
if _, err := w.Write(parser.adtsHeader[0:]); err != nil {
return err
}
if _, err := w.Write(src); err != nil {
return err
}
return nil
}
func (parser *Parser) SampleRate() int {
rate := 44100
if parser.cfgInfo.sampleRate <= byte(len(aacRates)-1) {
rate = aacRates[parser.cfgInfo.sampleRate]
}
return rate
}
func (parser *Parser) Parse(b []byte, packetType uint8, w io.Writer) (err error) {
switch packetType {
case av.AAC_SEQHDR:
err = parser.specificInfo(b)
case av.AAC_RAW:
err = parser.adts(b, w)
}
return
}
// SpecificConfig comment
// 0000 0|000 0|000 0|000
func SpecificConfig(objectType, samplingFrequencyIndex, channelConfig uint8) []byte {
data := []byte{0x00, 0x00}
data[0] |= objectType << 3
data[0] |= samplingFrequencyIndex >> 1
data[1] |= samplingFrequencyIndex << 7
data[1] |= channelConfig << 3
return data
//return []byte{0x15, 0x88}
}
+67
View File
@@ -0,0 +1,67 @@
package parser
import (
"errors"
"fmt"
"io"
"github.com/H0RlZ0N/gortmppush/av"
"github.com/H0RlZ0N/gortmppush/media/aac"
"github.com/H0RlZ0N/gortmppush/media/h264"
"github.com/H0RlZ0N/gortmppush/media/mp3"
)
var (
errNoAudio = errors.New("demuxer no audio")
)
type CodecParser struct {
aac *aac.Parser
mp3 *mp3.Parser
h264 *h264.Parser
}
func NewCodecParser() *CodecParser {
return &CodecParser{}
}
func (codeParser *CodecParser) SampleRate() (int, error) {
if codeParser.aac == nil && codeParser.mp3 == nil {
return 0, errNoAudio
}
if codeParser.aac != nil {
return codeParser.aac.SampleRate(), nil
}
return codeParser.mp3.SampleRate(), nil
}
func (codeParser *CodecParser) Parse(p *av.Packet, w io.Writer) (err error) {
switch p.PacketType {
case av.PacketTypeVideo:
if p.VHeader.CodecID == av.VIDEO_H264 {
if codeParser.h264 == nil {
codeParser.h264 = h264.NewParser()
}
isSeq := p.VHeader.FrameType == av.FRAME_KEY && p.VHeader.AVCPacketType == av.AVC_SEQHDR
err = codeParser.h264.Parse(p.Data, isSeq, w)
}
case av.PacketTypeAudio:
switch p.AHeader.SoundFormat {
case av.SOUND_AAC:
if codeParser.aac == nil {
codeParser.aac = aac.NewParser()
}
err = codeParser.aac.Parse(p.Data, p.AHeader.AACPacketType, w)
case av.SOUND_MP3:
if codeParser.mp3 == nil {
codeParser.mp3 = mp3.NewParser()
}
err = codeParser.mp3.Parse(p.Data)
}
default:
err = fmt.Errorf("Unknow packet type:%d", p.PacketType)
}
return
}
+385
View File
@@ -0,0 +1,385 @@
package h264
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
)
const (
i_frame byte = 0
p_frame byte = 1
b_frame byte = 2
)
const (
nalu_type_not_define byte = 0
nalu_type_slice byte = 1 //slice_layer_without_partioning_rbsp() sliceheader
nalu_type_dpa byte = 2 // slice_data_partition_a_layer_rbsp( ), slice_header
nalu_type_dpb byte = 3 // slice_data_partition_b_layer_rbsp( )
nalu_type_dpc byte = 4 // slice_data_partition_c_layer_rbsp( )
nalu_type_idr byte = 5 // slice_layer_without_partitioning_rbsp( ),sliceheader
nalu_type_sei byte = 6 //sei_rbsp( )
nalu_type_sps byte = 7 //seq_parameter_set_rbsp( )
nalu_type_pps byte = 8 //pic_parameter_set_rbsp( )
nalu_type_aud byte = 9 // access_unit_delimiter_rbsp( )
nalu_type_eoesq byte = 10 //end_of_seq_rbsp( )
nalu_type_eostream byte = 11 //end_of_stream_rbsp( )
nalu_type_filler byte = 12 //filler_data_rbsp( )
)
const (
naluBytesLen int = 4
maxSpsPpsLen int = 2 * 1024
)
var (
decDataNil = errors.New("dec buf is nil")
spsDataError = errors.New("sps data error")
ppsHeaderError = errors.New("pps header error")
ppsDataError = errors.New("pps data error")
naluHeaderInvalid = errors.New("nalu header invalid")
videoDataInvalid = errors.New("video data not match")
dataSizeNotMatch = errors.New("data size not match")
naluBodyLenError = errors.New("nalu body len error")
)
var StartCode3 = []byte{0x00, 0x00, 0x01}
var StartCode4 = []byte{0x00, 0x00, 0x00, 0x01}
var naluAud = []byte{0x00, 0x00, 0x00, 0x01, 0x09, 0xf0}
type Parser struct {
frameType byte
specificInfo []byte
pps *bytes.Buffer
}
// AVC sequence header & AAC sequence header
/*
class AVCDecoderConfigurationRecord {
unsigned int(8) configurationVersion = 1;
unsigned int(8) AVCProfileIndication
unsigned int(8) profile_compatiblity;
unsigned int(8) AVCLevelIndication
reserved bit(6) `111111`b;
unsigned int(2) lengthSizeMinusOne;
reserved bit(3) `111`b;
unsigned int(5) numOfSequenceParameterSets;
for (i = 0; i < numOfSequenceParameterSets; i++) {
unsigned int(16) sequenceParameterSetLength;
bit (8 * sequenceParameterSetLength) sequenceParameterSetNALUnit;
}
unsigned int (8) numOfPictureParameterSet;
for(i=0; i < numOfPictureParameterSets; i++) {
unsigned int(16) pictureParameterSetLength;
bit (8 * pictureParameterSetLength) pictureParameterSetNALUnit;
}
}
*/
type sequenceHeader struct {
configVersion byte //8bits 固定值1
//Baseline ProfileBP 66 视频会议和移动应用
//Main Profile (MP, 77): 标清电视
//High Profile HiP 100): 高清电视
avcProfileIndication byte //8bits
profileCompatility byte //8bits
//H.264的Level,通过level可指定最大的图像分辨率,帧率
avcLevelIndication byte //8bits
reserved1 byte //6bits `111111`
naluLen byte //2bits
reserved2 byte //3bits
spsNum byte //5bits sps的数量,一般为1
ppsNum byte //8bits pps的数量,一般为1
spsLen int //16bits sps长度,后面跟随spsNALUint
ppsLen int //16bits pps长度,后面跟随ppsNALUint
}
func NewParser() *Parser {
return &Parser{
pps: bytes.NewBuffer(make([]byte, maxSpsPpsLen)),
}
}
// return value 1:sps, value2 :pps
func (parser *Parser) parseSpecificInfo(src []byte) error {
if len(src) < 9 {
return decDataNil
}
sps := []byte{}
pps := []byte{}
var seq sequenceHeader
seq.configVersion = src[0]
seq.avcProfileIndication = src[1]
seq.profileCompatility = src[2]
seq.avcLevelIndication = src[3]
seq.reserved1 = src[4] & 0xfc
seq.naluLen = src[4]&0x03 + 1
seq.reserved2 = src[5] >> 5
//get sps
seq.spsNum = src[5] & 0x1f
seq.spsLen = int(src[6])<<8 | int(src[7])
if len(src[8:]) < seq.spsLen || seq.spsLen <= 0 {
return spsDataError
}
sps = append(sps, StartCode4...)
sps = append(sps, src[8:(8+seq.spsLen)]...)
//get pps
tmpBuf := src[(8 + seq.spsLen):]
if len(tmpBuf) < 4 {
return ppsHeaderError
}
seq.ppsNum = tmpBuf[0]
seq.ppsLen = int(0)<<16 | int(tmpBuf[1])<<8 | int(tmpBuf[2])
if len(tmpBuf[3:]) < seq.ppsLen || seq.ppsLen <= 0 {
return ppsDataError
}
pps = append(pps, StartCode4...)
pps = append(pps, tmpBuf[3:]...)
parser.specificInfo = append(parser.specificInfo, sps...)
parser.specificInfo = append(parser.specificInfo, pps...)
return nil
}
func (parser *Parser) isNaluHeader(src []byte) bool {
if len(src) < naluBytesLen {
return false
}
return src[0] == 0x00 &&
src[1] == 0x00 &&
src[2] == 0x00 &&
src[3] == 0x01
}
func (parser *Parser) naluSize(src []byte) (int, error) {
if len(src) < naluBytesLen {
return 0, errors.New("nalusizedata invalid")
}
buf := src[:naluBytesLen]
size := int(0)
for i := 0; i < len(buf); i++ {
size = size<<8 + int(buf[i])
}
return size, nil
}
func (parser *Parser) getAnnexbH264(src []byte, w io.Writer) error {
dataSize := len(src)
if dataSize < naluBytesLen {
return videoDataInvalid
}
parser.pps.Reset()
_, err := w.Write(naluAud)
if err != nil {
return err
}
index := 0
nalLen := 0
hasSpsPps := false
hasWriteSpsPps := false
for dataSize > 0 {
nalLen, err = parser.naluSize(src[index:])
if err != nil {
return dataSizeNotMatch
}
index += naluBytesLen
dataSize -= naluBytesLen
if dataSize >= nalLen && len(src[index:]) >= nalLen && nalLen > 0 {
nalType := src[index] & 0x1f
switch nalType {
case nalu_type_aud:
case nalu_type_idr:
if !hasWriteSpsPps {
hasWriteSpsPps = true
if !hasSpsPps {
if _, err := w.Write(parser.specificInfo); err != nil {
return err
}
} else {
if _, err := w.Write(parser.pps.Bytes()); err != nil {
return err
}
}
}
fallthrough
case nalu_type_slice:
fallthrough
case nalu_type_sei:
_, err := w.Write(StartCode4)
if err != nil {
return err
}
_, err = w.Write(src[index : index+nalLen])
if err != nil {
return err
}
case nalu_type_sps:
fallthrough
case nalu_type_pps:
hasSpsPps = true
_, err := parser.pps.Write(StartCode4)
if err != nil {
return err
}
_, err = parser.pps.Write(src[index : index+nalLen])
if err != nil {
return err
}
}
index += nalLen
dataSize -= nalLen
} else {
return naluBodyLenError
}
}
return nil
}
func (parser *Parser) Parse(b []byte, isSeq bool, w io.Writer) (err error) {
switch isSeq {
case true:
err = parser.parseSpecificInfo(b)
case false:
// is annexb
if parser.isNaluHeader(b) {
_, err = w.Write(b)
} else {
err = parser.getAnnexbH264(b, w)
}
}
return
}
// ParseNalus 把src的数据按h264分隔符解析出来
func ParseNalus(src []byte) (nalus [][]byte) {
return ParseNalusN(src, -1)
}
// ParseNalusN 按照00 00 00 01 分割nalu单元 n == 0 返回nil < 0 全部分割 > 0 最多分割 n 部分
func ParseNalusN(src []byte, n int) (nalus [][]byte) {
nalus = make([][]byte, 0)
if len(src) < naluBytesLen {
fmt.Printf("invalid nalu len:%d\n", len(src))
return nalus
}
if n == 0 {
return nalus
}
pre := 0
index := 0
for {
//先判断是否n已经为1了,如果是1,就不要继续分割,直接break
//如果n > 1 或 < 0,继续分割
if n == 1 {
index = len(src)
break
}
bfind := false
startCodeLength := 0
if bytes.Compare(src[index:index+3], StartCode3) == 0 {
startCodeLength = 3
bfind = true
} else if bytes.Compare(src[index:index+4], StartCode4) == 0 {
startCodeLength = 4
bfind = true
}
if bfind {
if index > pre {
nalu := make([]byte, 0)
nalu = append(nalu, src[pre:index]...)
nalus = append(nalus, nalu)
if n > 0 {
n--
}
}
index += startCodeLength
pre = index
continue
}
index++
if index+4 > len(src) {
break
}
}
if index > pre && pre < len(src) {
nalu := make([]byte, 0)
nalu = append(nalu, src[pre:]...)
nalus = append(nalus, nalu)
}
return
}
// AVCDecoderConfigurationRecord 生成h264的sequence header
func AVCDecoderConfigurationRecord(sps, pps []byte) []byte {
//sps pps 去掉start code,否则可能会有问题 sps pps的start code是4位的
if bytes.Compare(sps, StartCode4) == 0 {
sps = sps[4:]
}
if bytes.Compare(pps, StartCode4) == 0 {
pps = pps[4:]
}
sHeader := sequenceHeader{
configVersion: 1, //固定为1
avcProfileIndication: 66, //标清电视
profileCompatility: 0, //todo
avcLevelIndication: 40, //todo
reserved1: 0x3f, // 0011 1111
naluLen: 3, //todo
reserved2: 0x07, //0000 0111
spsNum: 1, //一般都只有一个
ppsNum: 1, //一般都只有一个
spsLen: len(sps),
ppsLen: len(pps),
}
index := 0
buffer := make([]byte, 11+len(sps)+len(pps))
buffer[index] = byte(sHeader.configVersion)
index++ //0
buffer[index] = byte(sHeader.avcProfileIndication)
index++ //1
buffer[index] = byte(sHeader.profileCompatility)
index++ //2
buffer[index] = byte(sHeader.avcLevelIndication)
index++ //3
//4
buffer[index] = sHeader.reserved1 << 2
buffer[index] |= sHeader.naluLen & 0x03
index++
//5
buffer[index] = sHeader.reserved2 << 5
buffer[index] |= sHeader.spsNum & 0x1F
index++
//sps length
binary.BigEndian.PutUint16(buffer[6:], uint16(sHeader.spsLen))
index += 2
copy(buffer[index:], sps)
index += len(sps)
//pps number
buffer[index] = 1
index++
binary.BigEndian.PutUint16(buffer[index:], uint16(sHeader.ppsLen))
index += 2
copy(buffer[index:], pps)
index += len(pps)
return buffer
}
+41
View File
@@ -0,0 +1,41 @@
package mp3
import "errors"
type Parser struct {
samplingFrequency int
}
func NewParser() *Parser {
return &Parser{}
}
// sampling_frequency - indicates the sampling frequency, according to the following table.
// '00' 44.1 kHz
// '01' 48 kHz
// '10' 32 kHz
// '11' reserved
var mp3Rates = []int{44100, 48000, 32000}
var (
errMp3DataInvalid = errors.New("mp3data invalid")
errIndexInvalid = errors.New("invalid rate index")
)
func (parser *Parser) Parse(src []byte) error {
if len(src) < 3 {
return errMp3DataInvalid
}
index := (src[2] >> 2) & 0x3
if index <= byte(len(mp3Rates)-1) {
parser.samplingFrequency = mp3Rates[index]
return nil
}
return errIndexInvalid
}
func (parser *Parser) SampleRate() int {
if parser.samplingFrequency == 0 {
parser.samplingFrequency = 44100
}
return parser.samplingFrequency
}
+50
View File
@@ -0,0 +1,50 @@
package amf
import (
"errors"
"fmt"
"io"
)
func (d *Decoder) DecodeBatch(r io.Reader, ver Version) (ret []interface{}, err error) {
var v interface{}
for {
v, err = d.Decode(r, ver)
if err != nil {
break
}
ret = append(ret, v)
}
return
}
func (d *Decoder) Decode(r io.Reader, ver Version) (interface{}, error) {
switch ver {
case 0:
return d.DecodeAmf0(r)
case 3:
return d.DecodeAmf3(r)
}
return nil, errors.New(fmt.Sprintf("decode amf: unsupported version %d", ver))
}
func (e *Encoder) EncodeBatch(w io.Writer, ver Version, val ...interface{}) (int, error) {
for _, v := range val {
if _, err := e.Encode(w, v, ver); err != nil {
return 0, err
}
}
return 0, nil
}
func (e *Encoder) Encode(w io.Writer, val interface{}, ver Version) (int, error) {
switch ver {
case AMF0:
return e.EncodeAmf0(w, val)
case AMF3:
return e.EncodeAmf3(w, val)
}
return 0, Error("encode amf: unsupported version %d", ver)
}
+206
View File
@@ -0,0 +1,206 @@
package amf
import (
"bytes"
"errors"
"fmt"
"reflect"
"testing"
"time"
)
func EncodeAndDecode(val interface{}, ver Version) (result interface{}, err error) {
enc := new(Encoder)
dec := new(Decoder)
buf := new(bytes.Buffer)
_, err = enc.Encode(buf, val, ver)
if err != nil {
return nil, errors.New(fmt.Sprintf("error in encode: %s", err))
}
result, err = dec.Decode(buf, ver)
if err != nil {
return nil, errors.New(fmt.Sprintf("error in decode: %s", err))
}
return
}
func Compare(val interface{}, ver Version, name string, t *testing.T) {
result, err := EncodeAndDecode(val, ver)
if err != nil {
t.Errorf("%s: %s", name, err)
}
if !reflect.DeepEqual(val, result) {
val_v := reflect.ValueOf(val)
result_v := reflect.ValueOf(result)
t.Errorf("%s: comparison failed between %+v (%s) and %+v (%s)", name, val, val_v.Type(), result, result_v.Type())
Dump("expected", val)
Dump("got", result)
}
// if val != result {
// t.Errorf("%s: comparison failed between %+v and %+v", name, val, result)
// }
}
func TestAmf0Number(t *testing.T) {
Compare(float64(3.14159), 0, "amf0 number float", t)
Compare(float64(124567890), 0, "amf0 number high", t)
Compare(float64(-34.2), 0, "amf0 number negative", t)
}
func TestAmf0String(t *testing.T) {
Compare("a pup!", 0, "amf0 string simple", t)
Compare("日本語", 0, "amf0 string utf8", t)
}
func TestAmf0Boolean(t *testing.T) {
Compare(true, 0, "amf0 boolean true", t)
Compare(false, 0, "amf0 boolean false", t)
}
func TestAmf0Null(t *testing.T) {
Compare(nil, 0, "amf0 boolean nil", t)
}
func TestAmf0Object(t *testing.T) {
obj := make(Object)
obj["dog"] = "alfie"
obj["coffee"] = true
obj["drugs"] = false
obj["pi"] = 3.14159
res, err := EncodeAndDecode(obj, 0)
if err != nil {
t.Errorf("amf0 object: %s", err)
}
result, ok := res.(Object)
if ok != true {
t.Errorf("amf0 object conversion failed")
}
if result["dog"] != "alfie" {
t.Errorf("amf0 object string: comparison failed")
}
if result["coffee"] != true {
t.Errorf("amf0 object true: comparison failed")
}
if result["drugs"] != false {
t.Errorf("amf0 object false: comparison failed")
}
if result["pi"] != float64(3.14159) {
t.Errorf("amf0 object float: comparison failed")
}
}
func TestAmf0Array(t *testing.T) {
arr := [5]float64{1, 2, 3, 4, 5}
res, err := EncodeAndDecode(arr, 0)
if err != nil {
t.Error("amf0 object: %s", err)
}
result, ok := res.(Array)
if ok != true {
t.Errorf("amf0 array conversion failed")
}
for i := 0; i < len(arr); i++ {
if arr[i] != result[i] {
t.Errorf("amf0 array %d comparison failed: %v / %v", i, arr[i], result[i])
}
}
}
func TestAmf3Integer(t *testing.T) {
Compare(int32(0), 3, "amf3 integer zero", t)
Compare(int32(1245), 3, "amf3 integer low", t)
Compare(int32(123456), 3, "amf3 integer high", t)
}
func TestAmf3Double(t *testing.T) {
Compare(float64(3.14159), 3, "amf3 double float", t)
Compare(float64(1234567890), 3, "amf3 double high", t)
Compare(float64(-12345), 3, "amf3 double negative", t)
}
func TestAmf3String(t *testing.T) {
Compare("a pup!", 0, "amf0 string simple", t)
Compare("日本語", 0, "amf0 string utf8", t)
}
func TestAmf3Boolean(t *testing.T) {
Compare(true, 3, "amf3 boolean true", t)
Compare(false, 3, "amf3 boolean false", t)
}
func TestAmf3Null(t *testing.T) {
Compare(nil, 3, "amf3 boolean nil", t)
}
func TestAmf3Date(t *testing.T) {
t1 := time.Unix(time.Now().Unix(), 0).UTC() // nanoseconds discarded
t2 := time.Date(1983, 9, 4, 12, 4, 8, 0, time.UTC)
Compare(t1, 3, "amf3 date now", t)
Compare(t2, 3, "amf3 date earlier", t)
}
func TestAmf3Array(t *testing.T) {
obj := make(Object)
obj["key"] = "val"
var arr Array
arr = append(arr, "amf")
arr = append(arr, float64(2))
arr = append(arr, -34.95)
arr = append(arr, true)
arr = append(arr, false)
res, err := EncodeAndDecode(arr, 3)
if err != nil {
t.Error("amf3 object: %s", err)
}
result, ok := res.(Array)
if ok != true {
t.Errorf("amf3 array conversion failed: %+v", res)
}
for i := 0; i < len(arr); i++ {
if arr[i] != result[i] {
t.Errorf("amf3 array %d comparison failed: %v / %v", i, arr[i], result[i])
}
}
}
func TestAmf3ByteArray(t *testing.T) {
enc := new(Encoder)
dec := new(Decoder)
buf := new(bytes.Buffer)
expect := []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x00}
enc.EncodeAmf3ByteArray(buf, expect, true)
result, err := dec.DecodeAmf3ByteArray(buf, true)
if err != nil {
t.Errorf("err: %s", err)
}
if bytes.Compare(result, expect) != 0 {
t.Errorf("expected: %+v, got %+v", expect, buf)
}
}
+105
View File
@@ -0,0 +1,105 @@
package amf
import (
"io"
)
const (
AMF0 = 0x00
AMF3 = 0x03
)
const (
AMF0_NUMBER_MARKER = 0x00
AMF0_BOOLEAN_MARKER = 0x01
AMF0_STRING_MARKER = 0x02
AMF0_OBJECT_MARKER = 0x03
AMF0_MOVIECLIP_MARKER = 0x04
AMF0_NULL_MARKER = 0x05
AMF0_UNDEFINED_MARKER = 0x06
AMF0_REFERENCE_MARKER = 0x07
AMF0_ECMA_ARRAY_MARKER = 0x08
AMF0_OBJECT_END_MARKER = 0x09
AMF0_STRICT_ARRAY_MARKER = 0x0a
AMF0_DATE_MARKER = 0x0b
AMF0_LONG_STRING_MARKER = 0x0c
AMF0_UNSUPPORTED_MARKER = 0x0d
AMF0_RECORDSET_MARKER = 0x0e
AMF0_XML_DOCUMENT_MARKER = 0x0f
AMF0_TYPED_OBJECT_MARKER = 0x10
AMF0_ACMPLUS_OBJECT_MARKER = 0x11
)
const (
AMF0_BOOLEAN_FALSE = 0x00
AMF0_BOOLEAN_TRUE = 0x01
AMF0_STRING_MAX = 65535
AMF3_INTEGER_MAX = 536870911
)
const (
AMF3_UNDEFINED_MARKER = 0x00
AMF3_NULL_MARKER = 0x01
AMF3_FALSE_MARKER = 0x02
AMF3_TRUE_MARKER = 0x03
AMF3_INTEGER_MARKER = 0x04
AMF3_DOUBLE_MARKER = 0x05
AMF3_STRING_MARKER = 0x06
AMF3_XMLDOC_MARKER = 0x07
AMF3_DATE_MARKER = 0x08
AMF3_ARRAY_MARKER = 0x09
AMF3_OBJECT_MARKER = 0x0a
AMF3_XMLSTRING_MARKER = 0x0b
AMF3_BYTEARRAY_MARKER = 0x0c
)
type ExternalHandler func(*Decoder, io.Reader) (interface{}, error)
type Decoder struct {
refCache []interface{}
stringRefs []string
objectRefs []interface{}
traitRefs []Trait
externalHandlers map[string]ExternalHandler
}
func NewDecoder() *Decoder {
return &Decoder{
externalHandlers: make(map[string]ExternalHandler),
}
}
func (d *Decoder) RegisterExternalHandler(name string, f ExternalHandler) {
d.externalHandlers[name] = f
}
type Encoder struct {
}
type Version uint8
type Array []interface{}
type Object map[string]interface{}
type TypedObject struct {
Type string
Object Object
}
type Trait struct {
Type string
Externalizable bool
Dynamic bool
Properties []string
}
func NewTrait() *Trait {
return &Trait{}
}
func NewTypedObject() *TypedObject {
return &TypedObject{
Type: "",
Object: make(Object),
}
}
+341
View File
@@ -0,0 +1,341 @@
package amf
import (
"encoding/binary"
"io"
)
// amf0 polymorphic router
func (d *Decoder) DecodeAmf0(r io.Reader) (interface{}, error) {
marker, err := ReadMarker(r)
if err != nil {
return nil, err
}
switch marker {
case AMF0_NUMBER_MARKER:
return d.DecodeAmf0Number(r, false)
case AMF0_BOOLEAN_MARKER:
return d.DecodeAmf0Boolean(r, false)
case AMF0_STRING_MARKER:
return d.DecodeAmf0String(r, false)
case AMF0_OBJECT_MARKER:
return d.DecodeAmf0Object(r, false)
case AMF0_MOVIECLIP_MARKER:
return nil, Error("decode amf0: unsupported type movieclip")
case AMF0_NULL_MARKER:
return d.DecodeAmf0Null(r, false)
case AMF0_UNDEFINED_MARKER:
return d.DecodeAmf0Undefined(r, false)
case AMF0_REFERENCE_MARKER:
return nil, Error("decode amf0: unsupported type reference")
case AMF0_ECMA_ARRAY_MARKER:
return d.DecodeAmf0EcmaArray(r, false)
case AMF0_STRICT_ARRAY_MARKER:
return d.DecodeAmf0StrictArray(r, false)
case AMF0_DATE_MARKER:
return d.DecodeAmf0Date(r, false)
case AMF0_LONG_STRING_MARKER:
return d.DecodeAmf0LongString(r, false)
case AMF0_UNSUPPORTED_MARKER:
return d.DecodeAmf0Unsupported(r, false)
case AMF0_RECORDSET_MARKER:
return nil, Error("decode amf0: unsupported type recordset")
case AMF0_XML_DOCUMENT_MARKER:
return d.DecodeAmf0XmlDocument(r, false)
case AMF0_TYPED_OBJECT_MARKER:
return d.DecodeAmf0TypedObject(r, false)
case AMF0_ACMPLUS_OBJECT_MARKER:
return d.DecodeAmf3(r)
}
return nil, Error("decode amf0: unsupported type %d", marker)
}
// marker: 1 byte 0x00
// format: 8 byte big endian float64
func (d *Decoder) DecodeAmf0Number(r io.Reader, decodeMarker bool) (result float64, err error) {
if err = AssertMarker(r, decodeMarker, AMF0_NUMBER_MARKER); err != nil {
return
}
err = binary.Read(r, binary.BigEndian, &result)
if err != nil {
return float64(0), Error("amf0 decode: unable to read number: %s", err)
}
return
}
// marker: 1 byte 0x01
// format: 1 byte, 0x00 = false, 0x01 = true
func (d *Decoder) DecodeAmf0Boolean(r io.Reader, decodeMarker bool) (result bool, err error) {
if err = AssertMarker(r, decodeMarker, AMF0_BOOLEAN_MARKER); err != nil {
return
}
var b byte
if b, err = ReadByte(r); err != nil {
return
}
if b == AMF0_BOOLEAN_FALSE {
return false, nil
} else if b == AMF0_BOOLEAN_TRUE {
return true, nil
}
return false, Error("decode amf0: unexpected value %v for boolean", b)
}
// marker: 1 byte 0x02
// format:
// - 2 byte big endian uint16 header to determine size
// - n (size) byte utf8 string
func (d *Decoder) DecodeAmf0String(r io.Reader, decodeMarker bool) (result string, err error) {
if err = AssertMarker(r, decodeMarker, AMF0_STRING_MARKER); err != nil {
return
}
var length uint16
err = binary.Read(r, binary.BigEndian, &length)
if err != nil {
return "", Error("decode amf0: unable to decode string length: %s", err)
}
//if string length is zero, no readbytes
if length == 0 {
return "", nil
}
var bytes = make([]byte, length)
if bytes, err = ReadBytes(r, int(length)); err != nil {
return "", Error("decode amf0: unable to decode string value: %s", err)
}
return string(bytes), nil
}
// marker: 1 byte 0x03
// format:
// - loop encoded string followed by encoded value
// - terminated with empty string followed by 1 byte 0x09
func (d *Decoder) DecodeAmf0Object(r io.Reader, decodeMarker bool) (Object, error) {
if err := AssertMarker(r, decodeMarker, AMF0_OBJECT_MARKER); err != nil {
return nil, err
}
result := make(Object)
d.refCache = append(d.refCache, result)
for {
key, err := d.DecodeAmf0String(r, false)
if err != nil {
return nil, err
}
if key == "" {
if err = AssertMarker(r, true, AMF0_OBJECT_END_MARKER); err != nil {
return nil, Error("decode amf0: expected object end marker: %s", err)
}
break
}
value, err := d.DecodeAmf0(r)
if err != nil {
return nil, Error("decode amf0: unable to decode object value: %s", err)
}
result[key] = value
}
return result, nil
}
// marker: 1 byte 0x05
// no additional data
func (d *Decoder) DecodeAmf0Null(r io.Reader, decodeMarker bool) (result interface{}, err error) {
err = AssertMarker(r, decodeMarker, AMF0_NULL_MARKER)
return
}
// marker: 1 byte 0x06
// no additional data
func (d *Decoder) DecodeAmf0Undefined(r io.Reader, decodeMarker bool) (result interface{}, err error) {
err = AssertMarker(r, decodeMarker, AMF0_UNDEFINED_MARKER)
return
}
// marker: 1 byte 0x07
// format: 2 byte big endian uint16
/*
func (d *Decoder) DecodeAmf0Reference(r io.Reader, decodeMarker bool) (interface{}, error) {
if err := AssertMarker(r, decodeMarker, AMF0_REFERENCE_MARKER); err != nil {
return nil, err
}
var err error
var ref uint16
err = binary.Read(r, binary.BigEndian, &ref)
if err != nil {
return nil, Error("decode amf0: unable to decode reference id: %s", err)
}
if int(ref) > len(d.refCache) {
return nil, Error("decode amf0: bad reference %d (current length %d)", ref, len(d.refCache))
}
result := d.refCache[ref]
return result, nil
}
*/
// marker: 1 byte 0x08
// format:
// - 4 byte big endian uint32 with length of associative array
// - normal object format:
// - loop encoded string followed by encoded value
// - terminated with empty string followed by 1 byte 0x09
func (d *Decoder) DecodeAmf0EcmaArray(r io.Reader, decodeMarker bool) (Object, error) {
if err := AssertMarker(r, decodeMarker, AMF0_ECMA_ARRAY_MARKER); err != nil {
return nil, err
}
var length uint32
err := binary.Read(r, binary.BigEndian, &length)
result, err := d.DecodeAmf0Object(r, false)
if err != nil {
return nil, Error("decode amf0: unable to decode ecma array object: %s", err)
}
return result, nil
}
// marker: 1 byte 0x0a
// format:
// - 4 byte big endian uint32 to determine length of associative array
// - n (length) encoded values
func (d *Decoder) DecodeAmf0StrictArray(r io.Reader, decodeMarker bool) (result Array, err error) {
if err := AssertMarker(r, decodeMarker, AMF0_STRICT_ARRAY_MARKER); err != nil {
return nil, err
}
var length uint32
err = binary.Read(r, binary.BigEndian, &length)
if err != nil {
return nil, Error("decode amf0: unable to decode strict array length: %s", err)
}
d.refCache = append(d.refCache, result)
for i := uint32(0); i < length; i++ {
tmp, err := d.DecodeAmf0(r)
if err != nil {
return nil, Error("decode amf0: unable to decode strict array object: %s", err)
}
result = append(result, tmp)
}
return result, nil
}
// marker: 1 byte 0x0b
// format:
// - normal number format:
// - 8 byte big endian float64
//
// - 2 byte unused
func (d *Decoder) DecodeAmf0Date(r io.Reader, decodeMarker bool) (result float64, err error) {
if err = AssertMarker(r, decodeMarker, AMF0_DATE_MARKER); err != nil {
return
}
if result, err = d.DecodeAmf0Number(r, false); err != nil {
return float64(0), Error("decode amf0: unable to decode float in date: %s", err)
}
if _, err = ReadBytes(r, 2); err != nil {
return float64(0), Error("decode amf0: unable to read 2 trail bytes in date: %s", err)
}
return
}
// marker: 1 byte 0x0c
// format:
// - 4 byte big endian uint32 header to determine size
// - n (size) byte utf8 string
func (d *Decoder) DecodeAmf0LongString(r io.Reader, decodeMarker bool) (result string, err error) {
if err = AssertMarker(r, decodeMarker, AMF0_LONG_STRING_MARKER); err != nil {
return
}
var length uint32
err = binary.Read(r, binary.BigEndian, &length)
if err != nil {
return "", Error("decode amf0: unable to decode long string length: %s", err)
}
var bytes = make([]byte, length)
if bytes, err = ReadBytes(r, int(length)); err != nil {
return "", Error("decode amf0: unable to decode long string value: %s", err)
}
return string(bytes), nil
}
// marker: 1 byte 0x0d
// no additional data
func (d *Decoder) DecodeAmf0Unsupported(r io.Reader, decodeMarker bool) (result interface{}, err error) {
err = AssertMarker(r, decodeMarker, AMF0_UNSUPPORTED_MARKER)
return
}
// marker: 1 byte 0x0f
// format:
// - normal long string format
// - 4 byte big endian uint32 header to determine size
// - n (size) byte utf8 string
func (d *Decoder) DecodeAmf0XmlDocument(r io.Reader, decodeMarker bool) (result string, err error) {
if err = AssertMarker(r, decodeMarker, AMF0_XML_DOCUMENT_MARKER); err != nil {
return
}
return d.DecodeAmf0LongString(r, false)
}
// marker: 1 byte 0x10
// format:
// - normal string format:
// - 2 byte big endian uint16 header to determine size
// - n (size) byte utf8 string
//
// - normal object format:
// - loop encoded string followed by encoded value
// - terminated with empty string followed by 1 byte 0x09
func (d *Decoder) DecodeAmf0TypedObject(r io.Reader, decodeMarker bool) (TypedObject, error) {
result := *new(TypedObject)
err := AssertMarker(r, decodeMarker, AMF0_TYPED_OBJECT_MARKER)
if err != nil {
return result, err
}
d.refCache = append(d.refCache, result)
result.Type, err = d.DecodeAmf0String(r, false)
if err != nil {
return result, Error("decode amf0: typed object unable to determine type: %s", err)
}
result.Object, err = d.DecodeAmf0Object(r, false)
if err != nil {
return result, Error("decode amf0: typed object unable to determine object: %s", err)
}
return result, nil
}
+588
View File
@@ -0,0 +1,588 @@
package amf
import (
"bytes"
"testing"
)
func TestDecodeAmf0Number(t *testing.T) {
buf := bytes.NewReader([]byte{0x00, 0x3f, 0xf3, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33})
expect := float64(1.2)
dec := &Decoder{}
// Test main interface
got, err := dec.DecodeAmf0(buf)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
// Test number interface with marker
buf.Seek(0, 0)
got, err = dec.DecodeAmf0Number(buf, true)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
// Test number interface without marker
buf.Seek(1, 0)
got, err = dec.DecodeAmf0Number(buf, false)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
}
func TestDecodeAmf0BooleanTrue(t *testing.T) {
buf := bytes.NewReader([]byte{0x01, 0x01})
expect := true
dec := &Decoder{}
// Test main interface
got, err := dec.DecodeAmf0(buf)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
// Test boolean interface with marker
buf.Seek(0, 0)
got, err = dec.DecodeAmf0Boolean(buf, true)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
// Test boolean interface without marker
buf.Seek(1, 0)
got, err = dec.DecodeAmf0Boolean(buf, false)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
}
func TestDecodeAmf0BooleanFalse(t *testing.T) {
buf := bytes.NewReader([]byte{0x01, 0x00})
expect := false
dec := &Decoder{}
// Test main interface
got, err := dec.DecodeAmf0(buf)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
// Test boolean interface with marker
buf.Seek(0, 0)
got, err = dec.DecodeAmf0Boolean(buf, true)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
// Test boolean interface without marker
buf.Seek(1, 0)
got, err = dec.DecodeAmf0Boolean(buf, false)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
}
func TestDecodeAmf0String(t *testing.T) {
buf := bytes.NewReader([]byte{0x02, 0x00, 0x03, 0x66, 0x6f, 0x6f})
expect := "foo"
dec := &Decoder{}
// Test main interface
got, err := dec.DecodeAmf0(buf)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
// Test string interface with marker
buf.Seek(0, 0)
got, err = dec.DecodeAmf0String(buf, true)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
// Test string interface without marker
buf.Seek(1, 0)
got, err = dec.DecodeAmf0String(buf, false)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
}
func TestDecodeAmf0Object(t *testing.T) {
buf := bytes.NewReader([]byte{0x03, 0x00, 0x03, 0x66, 0x6f, 0x6f, 0x02, 0x00, 0x03, 0x62, 0x61, 0x72, 0x00, 0x00, 0x09})
dec := &Decoder{}
// Test main interface
got, err := dec.DecodeAmf0(buf)
if err != nil {
t.Errorf("%s", err)
}
obj, ok := got.(Object)
if ok != true {
t.Errorf("expected result to cast to object")
}
if obj["foo"] != "bar" {
t.Errorf("expected {'foo'='bar'}, got %v", obj)
}
// Test object interface with marker
buf.Seek(0, 0)
got, err = dec.DecodeAmf0Object(buf, true)
if err != nil {
t.Errorf("%s", err)
}
obj, ok = got.(Object)
if ok != true {
t.Errorf("expected result to cast to object")
}
if obj["foo"] != "bar" {
t.Errorf("expected {'foo'='bar'}, got %v", obj)
}
// Test object interface without marker
buf.Seek(1, 0)
got, err = dec.DecodeAmf0Object(buf, false)
if err != nil {
t.Errorf("%s", err)
}
obj, ok = got.(Object)
if ok != true {
t.Errorf("expected result to cast to object")
}
if obj["foo"] != "bar" {
t.Errorf("expected {'foo'='bar'}, got %v", obj)
}
}
func TestDecodeAmf0Null(t *testing.T) {
buf := bytes.NewReader([]byte{0x05})
dec := &Decoder{}
// Test main interface
got, err := dec.DecodeAmf0(buf)
if err != nil {
t.Errorf("%s", err)
}
if got != nil {
t.Errorf("expect nil got %v", got)
}
// Test null interface with marker
buf.Seek(0, 0)
got, err = dec.DecodeAmf0Null(buf, true)
if err != nil {
t.Errorf("%s", err)
}
if got != nil {
t.Errorf("expect nil got %v", got)
}
}
func TestDecodeAmf0Undefined(t *testing.T) {
buf := bytes.NewReader([]byte{0x06})
dec := &Decoder{}
// Test main interface
got, err := dec.DecodeAmf0(buf)
if err != nil {
t.Errorf("%s", err)
}
if got != nil {
t.Errorf("expect nil got %v", got)
}
// Test undefined interface with marker
buf.Seek(0, 0)
got, err = dec.DecodeAmf0Undefined(buf, true)
if err != nil {
t.Errorf("%s", err)
}
if got != nil {
t.Errorf("expect nil got %v", got)
}
}
/*
func TestDecodeReference(t *testing.T) {
buf := bytes.NewReader([]byte{0x03, 0x00, 0x03, 0x66, 0x6f, 0x6f, 0x07, 0x00, 0x00, 0x00, 0x00, 0x09})
dec := &Decoder{}
got, err := dec.DecodeAmf0(buf)
if err != nil {
t.Errorf("%s", err)
}
obj, ok := got.(Object)
if ok != true {
t.Errorf("expected result to cast to object")
}
_, ok2 := obj["foo"].(Object)
if ok2 != true {
t.Errorf("expected foo value to cast to object")
}
}
*/
func TestDecodeAmf0EcmaArray(t *testing.T) {
buf := bytes.NewReader([]byte{0x08, 0x00, 0x00, 0x00, 0x01, 0x00, 0x03, 0x66, 0x6f, 0x6f, 0x02, 0x00, 0x03, 0x62, 0x61, 0x72, 0x00, 0x00, 0x09})
dec := &Decoder{}
// Test main interface
got, err := dec.DecodeAmf0(buf)
if err != nil {
t.Errorf("%s", err)
}
obj, ok := got.(Object)
if ok != true {
t.Errorf("expected result to cast to object")
}
if obj["foo"] != "bar" {
t.Errorf("expected {'foo'='bar'}, got %v", obj)
}
// Test ecma array interface with marker
buf.Seek(0, 0)
got, err = dec.DecodeAmf0EcmaArray(buf, true)
if err != nil {
t.Errorf("%s", err)
}
obj, ok = got.(Object)
if ok != true {
t.Errorf("expected result to cast to object")
}
if obj["foo"] != "bar" {
t.Errorf("expected {'foo'='bar'}, got %v", obj)
}
// Test ecma array interface without marker
buf.Seek(1, 0)
got, err = dec.DecodeAmf0EcmaArray(buf, false)
if err != nil {
t.Errorf("%s", err)
}
obj, ok = got.(Object)
if ok != true {
t.Errorf("expected result to cast to ecma array")
}
if obj["foo"] != "bar" {
t.Errorf("expected {'foo'='bar'}, got %v", obj)
}
}
func TestDecodeAmf0StrictArray(t *testing.T) {
buf := bytes.NewReader([]byte{0x0a, 0x00, 0x00, 0x00, 0x03, 0x00, 0x40, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x03, 0x66, 0x6f, 0x6f, 0x05})
dec := &Decoder{}
// Test main interface
got, err := dec.DecodeAmf0(buf)
if err != nil {
t.Errorf("%s", err)
}
arr, ok := got.(Array)
if ok != true {
t.Errorf("expected result to cast to strict array")
}
if arr[0] != float64(5) {
t.Errorf("expected array[0] to be 5, got %v", arr[0])
}
if arr[1] != "foo" {
t.Errorf("expected array[1] to be 'foo', got %v", arr[1])
}
if arr[2] != nil {
t.Errorf("expected array[2] to be nil, got %v", arr[2])
}
// Test strict array interface with marker
buf.Seek(0, 0)
got, err = dec.DecodeAmf0StrictArray(buf, true)
if err != nil {
t.Errorf("%s", err)
}
arr, ok = got.(Array)
if ok != true {
t.Errorf("expected result to cast to strict array")
}
if arr[0] != float64(5) {
t.Errorf("expected array[0] to be 5, got %v", arr[0])
}
if arr[1] != "foo" {
t.Errorf("expected array[1] to be 'foo', got %v", arr[1])
}
if arr[2] != nil {
t.Errorf("expected array[2] to be nil, got %v", arr[2])
}
// Test strict array interface without marker
buf.Seek(1, 0)
got, err = dec.DecodeAmf0StrictArray(buf, false)
if err != nil {
t.Errorf("%s", err)
}
arr, ok = got.(Array)
if ok != true {
t.Errorf("expected result to cast to strict array")
}
if arr[0] != float64(5) {
t.Errorf("expected array[0] to be 5, got %v", arr[0])
}
if arr[1] != "foo" {
t.Errorf("expected array[1] to be 'foo', got %v", arr[1])
}
if arr[2] != nil {
t.Errorf("expected array[2] to be nil, got %v", arr[2])
}
}
func TestDecodeAmf0Date(t *testing.T) {
buf := bytes.NewReader([]byte{0x0b, 0x40, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00})
expect := float64(5)
dec := &Decoder{}
// Test main interface
got, err := dec.DecodeAmf0(buf)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
// Test date interface with marker
buf.Seek(0, 0)
got, err = dec.DecodeAmf0Date(buf, true)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
// Test date interface without marker
buf.Seek(1, 0)
got, err = dec.DecodeAmf0Date(buf, false)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
}
func TestDecodeAmf0LongString(t *testing.T) {
buf := bytes.NewReader([]byte{0x0c, 0x00, 0x00, 0x00, 0x03, 0x66, 0x6f, 0x6f})
expect := "foo"
dec := &Decoder{}
// Test main interface
got, err := dec.DecodeAmf0(buf)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
// Test long string interface with marker
buf.Seek(0, 0)
got, err = dec.DecodeAmf0LongString(buf, true)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
// Test long string interface without marker
buf.Seek(1, 0)
got, err = dec.DecodeAmf0LongString(buf, false)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
}
func TestDecodeAmf0Unsupported(t *testing.T) {
buf := bytes.NewReader([]byte{0x0d})
dec := &Decoder{}
// Test main interface
got, err := dec.DecodeAmf0(buf)
if err != nil {
t.Errorf("%s", err)
}
if got != nil {
t.Errorf("expect nil got %v", got)
}
// Test unsupported interface with marker
buf.Seek(0, 0)
got, err = dec.DecodeAmf0Unsupported(buf, true)
if err != nil {
t.Errorf("%s", err)
}
if got != nil {
t.Errorf("expect nil got %v", got)
}
}
func TestDecodeAmf0XmlDocument(t *testing.T) {
buf := bytes.NewReader([]byte{0x0f, 0x00, 0x00, 0x00, 0x03, 0x66, 0x6f, 0x6f})
expect := "foo"
dec := &Decoder{}
// Test main interface
got, err := dec.DecodeAmf0(buf)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
// Test long string interface with marker
buf.Seek(0, 0)
got, err = dec.DecodeAmf0XmlDocument(buf, true)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
// Test long string interface without marker
buf.Seek(1, 0)
got, err = dec.DecodeAmf0XmlDocument(buf, false)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
}
func TestDecodeAmf0TypedObject(t *testing.T) {
buf := bytes.NewReader([]byte{
0x10, 0x00, 0x0F, 'o', 'r', 'g',
'.', 'a', 'm', 'f', '.', 'A',
'S', 'C', 'l', 'a', 's', 's',
0x00, 0x03, 'b', 'a', 'z', 0x05,
0x00, 0x03, 'f', 'o', 'o', 0x02,
0x00, 0x03, 'b', 'a', 'r', 0x00,
0x00, 0x09,
})
dec := &Decoder{}
// Test main interface
got, err := dec.DecodeAmf0(buf)
if err != nil {
t.Errorf("%s", err)
}
tobj, ok := got.(TypedObject)
if ok != true {
t.Errorf("expected result to cast to typed object, got %+v", tobj)
}
if tobj.Type != "org.amf.ASClass" {
t.Errorf("expected typed object type to be 'class', got %v", tobj.Type)
}
if tobj.Object["foo"] != "bar" {
t.Errorf("expected typed object object foo to eql bar, got %v", tobj.Object["foo"])
}
if tobj.Object["baz"] != nil {
t.Errorf("expected typed object object baz to nil, got %v", tobj.Object["baz"])
}
// Test typed object interface with marker
buf.Seek(0, 0)
got, err = dec.DecodeAmf0TypedObject(buf, true)
if err != nil {
t.Errorf("%s", err)
}
tobj, ok = got.(TypedObject)
if ok != true {
t.Errorf("expected result to cast to typed object, got %+v", tobj)
}
if tobj.Type != "org.amf.ASClass" {
t.Errorf("expected typed object type to be 'class', got %v", tobj.Type)
}
if tobj.Object["foo"] != "bar" {
t.Errorf("expected typed object object foo to eql bar, got %v", tobj.Object["foo"])
}
if tobj.Object["baz"] != nil {
t.Errorf("expected typed object object baz to nil, got %v", tobj.Object["baz"])
}
// Test typed object interface without marker
buf.Seek(1, 0)
got, err = dec.DecodeAmf0TypedObject(buf, false)
if err != nil {
t.Errorf("%s", err)
}
tobj, ok = got.(TypedObject)
if ok != true {
t.Errorf("expected result to cast to typed object, got %+v", tobj)
}
if tobj.Type != "org.amf.ASClass" {
t.Errorf("expected typed object type to be 'class', got %v", tobj.Type)
}
if tobj.Object["foo"] != "bar" {
t.Errorf("expected typed object object foo to eql bar, got %v", tobj.Object["foo"])
}
if tobj.Object["baz"] != nil {
t.Errorf("expected typed object object baz to nil, got %v", tobj.Object["baz"])
}
}
+496
View File
@@ -0,0 +1,496 @@
package amf
import (
"encoding/binary"
"io"
"time"
)
// amf3 polymorphic router
func (d *Decoder) DecodeAmf3(r io.Reader) (interface{}, error) {
marker, err := ReadMarker(r)
if err != nil {
return nil, err
}
switch marker {
case AMF3_UNDEFINED_MARKER:
return d.DecodeAmf3Undefined(r, false)
case AMF3_NULL_MARKER:
return d.DecodeAmf3Null(r, false)
case AMF3_FALSE_MARKER:
return d.DecodeAmf3False(r, false)
case AMF3_TRUE_MARKER:
return d.DecodeAmf3True(r, false)
case AMF3_INTEGER_MARKER:
return d.DecodeAmf3Integer(r, false)
case AMF3_DOUBLE_MARKER:
return d.DecodeAmf3Double(r, false)
case AMF3_STRING_MARKER:
return d.DecodeAmf3String(r, false)
case AMF3_XMLDOC_MARKER:
return d.DecodeAmf3Xml(r, false)
case AMF3_DATE_MARKER:
return d.DecodeAmf3Date(r, false)
case AMF3_ARRAY_MARKER:
return d.DecodeAmf3Array(r, false)
case AMF3_OBJECT_MARKER:
return d.DecodeAmf3Object(r, false)
case AMF3_XMLSTRING_MARKER:
return d.DecodeAmf3Xml(r, false)
case AMF3_BYTEARRAY_MARKER:
return d.DecodeAmf3ByteArray(r, false)
}
return nil, Error("decode amf3: unsupported type %d", marker)
}
// marker: 1 byte 0x00
// no additional data
func (d *Decoder) DecodeAmf3Undefined(r io.Reader, decodeMarker bool) (result interface{}, err error) {
err = AssertMarker(r, decodeMarker, AMF3_UNDEFINED_MARKER)
return
}
// marker: 1 byte 0x01
// no additional data
func (d *Decoder) DecodeAmf3Null(r io.Reader, decodeMarker bool) (result interface{}, err error) {
err = AssertMarker(r, decodeMarker, AMF3_NULL_MARKER)
return
}
// marker: 1 byte 0x02
// no additional data
func (d *Decoder) DecodeAmf3False(r io.Reader, decodeMarker bool) (result bool, err error) {
err = AssertMarker(r, decodeMarker, AMF3_FALSE_MARKER)
result = false
return
}
// marker: 1 byte 0x03
// no additional data
func (d *Decoder) DecodeAmf3True(r io.Reader, decodeMarker bool) (result bool, err error) {
err = AssertMarker(r, decodeMarker, AMF3_TRUE_MARKER)
result = true
return
}
// marker: 1 byte 0x04
func (d *Decoder) DecodeAmf3Integer(r io.Reader, decodeMarker bool) (result int32, err error) {
if err = AssertMarker(r, decodeMarker, AMF3_INTEGER_MARKER); err != nil {
return
}
var u29 uint32
u29, err = d.decodeU29(r)
if err != nil {
return
}
result = int32(u29)
if result > 0xfffffff {
result = int32(u29 - 0x20000000)
}
return
}
// marker: 1 byte 0x05
func (d *Decoder) DecodeAmf3Double(r io.Reader, decodeMarker bool) (result float64, err error) {
if err = AssertMarker(r, decodeMarker, AMF3_DOUBLE_MARKER); err != nil {
return
}
err = binary.Read(r, binary.BigEndian, &result)
if err != nil {
return float64(0), Error("amf3 decode: unable to read double: %s", err)
}
return
}
// marker: 1 byte 0x06
// format:
// - u29 reference int. if reference, no more data. if not reference,
// length value of bytes to read to complete string.
func (d *Decoder) DecodeAmf3String(r io.Reader, decodeMarker bool) (result string, err error) {
if err = AssertMarker(r, decodeMarker, AMF3_STRING_MARKER); err != nil {
return
}
var isRef bool
var refVal uint32
isRef, refVal, err = d.decodeReferenceInt(r)
if err != nil {
return "", Error("amf3 decode: unable to decode string reference and length: %s", err)
}
if isRef {
result = d.stringRefs[refVal]
return
}
buf := make([]byte, refVal)
_, err = r.Read(buf)
if err != nil {
return "", Error("amf3 decode: unable to read string: %s", err)
}
result = string(buf)
if result != "" {
d.stringRefs = append(d.stringRefs, result)
}
return
}
// marker: 1 byte 0x08
// format:
// - u29 reference int, if reference, no more data
// - timestamp double
func (d *Decoder) DecodeAmf3Date(r io.Reader, decodeMarker bool) (result time.Time, err error) {
if err = AssertMarker(r, decodeMarker, AMF3_DATE_MARKER); err != nil {
return
}
var isRef bool
var refVal uint32
isRef, refVal, err = d.decodeReferenceInt(r)
if err != nil {
return result, Error("amf3 decode: unable to decode date reference and length: %s", err)
}
if isRef {
res, ok := d.objectRefs[refVal].(time.Time)
if ok != true {
return result, Error("amf3 decode: unable to extract time from date object references")
}
return res, err
}
var u64 float64
err = binary.Read(r, binary.BigEndian, &u64)
if err != nil {
return result, Error("amf3 decode: unable to read double: %s", err)
}
result = time.Unix(int64(u64/1000), 0).UTC()
d.objectRefs = append(d.objectRefs, result)
return
}
// marker: 1 byte 0x09
// format:
// - u29 reference int. if reference, no more data.
// - string representing associative array if present
// - n values (length of u29)
func (d *Decoder) DecodeAmf3Array(r io.Reader, decodeMarker bool) (result Array, err error) {
if err = AssertMarker(r, decodeMarker, AMF3_ARRAY_MARKER); err != nil {
return
}
var isRef bool
var refVal uint32
isRef, refVal, err = d.decodeReferenceInt(r)
if err != nil {
return result, Error("amf3 decode: unable to decode array reference and length: %s", err)
}
if isRef {
objRefId := refVal >> 1
res, ok := d.objectRefs[objRefId].(Array)
if ok != true {
return result, Error("amf3 decode: unable to extract array from object references")
}
return res, err
}
var key string
key, err = d.DecodeAmf3String(r, false)
if err != nil {
return result, Error("amf3 decode: unable to read key for array: %s", err)
}
if key != "" {
return result, Error("amf3 decode: array key is not empty, can't handle associative array")
}
for i := uint32(0); i < refVal; i++ {
tmp, err := d.DecodeAmf3(r)
if err != nil {
return result, Error("amf3 decode: array element could not be decoded: %s", err)
}
result = append(result, tmp)
}
d.objectRefs = append(d.objectRefs, result)
return
}
// marker: 1 byte 0x09
// format: oh dear god
func (d *Decoder) DecodeAmf3Object(r io.Reader, decodeMarker bool) (result interface{}, err error) {
if err = AssertMarker(r, decodeMarker, AMF3_OBJECT_MARKER); err != nil {
return nil, err
}
// decode the initial u29
isRef, refVal, err := d.decodeReferenceInt(r)
if err != nil {
return nil, Error("amf3 decode: unable to decode object reference and length: %s", err)
}
// if this is a object reference only, grab it and return it
if isRef {
objRefId := refVal >> 1
return d.objectRefs[objRefId], nil
}
// each type has traits that are cached, if the peer sent a reference
// then we'll need to look it up and use it.
var trait Trait
traitIsRef := (refVal & 0x01) == 0
if traitIsRef {
traitRef := refVal >> 1
trait = d.traitRefs[traitRef]
} else {
// build a new trait from what's left of the given u29
trait = *NewTrait()
trait.Externalizable = (refVal & 0x02) != 0
trait.Dynamic = (refVal & 0x04) != 0
var cls string
cls, err = d.DecodeAmf3String(r, false)
if err != nil {
return result, Error("amf3 decode: unable to read trait type for object: %s", err)
}
trait.Type = cls
// traits have property keys, encoded as amf3 strings
propLength := refVal >> 3
for i := uint32(0); i < propLength; i++ {
tmp, err := d.DecodeAmf3String(r, false)
if err != nil {
return result, Error("amf3 decode: unable to read trait property for object: %s", err)
}
trait.Properties = append(trait.Properties, tmp)
}
d.traitRefs = append(d.traitRefs, trait)
}
d.objectRefs = append(d.objectRefs, result)
// objects can be externalizable, meaning that the system has no concrete understanding of
// their properties or how they are encoded. in that case, we need to find and delegate behavior
// to the right object.
if trait.Externalizable {
switch trait.Type {
case "DSA": // AsyncMessageExt
result, err = d.decodeAsyncMessageExt(r)
if err != nil {
return result, Error("amf3 decode: unable to decode dsa: %s", err)
}
case "DSK": // AcknowledgeMessageExt
result, err = d.decodeAcknowledgeMessageExt(r)
if err != nil {
return result, Error("amf3 decode: unable to decode dsk: %s", err)
}
case "flex.messaging.io.ArrayCollection":
result, err = d.decodeArrayCollection(r)
if err != nil {
return result, Error("amf3 decode: unable to decode ac: %s", err)
}
// store an extra reference to array collection container
d.objectRefs = append(d.objectRefs, result)
default:
fn, ok := d.externalHandlers[trait.Type]
if ok {
result, err = fn(d, r)
if err != nil {
return result, Error("amf3 decode: unable to call external decoder for type %s: %s", trait.Type, err)
}
} else {
return result, Error("amf3 decode: unable to decode external type %s, no handler", trait.Type)
}
}
return result, err
}
var key string
var val interface{}
var obj Object
obj = make(Object)
// non-externalizable objects have property keys in traits, iterate through them
// and add the read values to the object
for _, key = range trait.Properties {
val, err = d.DecodeAmf3(r)
if err != nil {
return result, Error("amf3 decode: unable to decode object property: %s", err)
}
obj[key] = val
}
// if an object is dynamic, it can have extra key/value data at the end. in this case,
// read keys until we get an empty one.
if trait.Dynamic {
for {
key, err = d.DecodeAmf3String(r, false)
if err != nil {
return result, Error("amf3 decode: unable to decode dynamic key: %s", err)
}
if key == "" {
break
}
val, err = d.DecodeAmf3(r)
if err != nil {
return result, Error("amf3 decode: unable to decode dynamic value: %s", err)
}
obj[key] = val
}
}
result = obj
return
}
// marker: 1 byte 0x07 or 0x0b
// format:
// - u29 reference int. if reference, no more data. if not reference,
// length value of bytes to read to complete string.
func (d *Decoder) DecodeAmf3Xml(r io.Reader, decodeMarker bool) (result string, err error) {
if decodeMarker {
var marker byte
marker, err = ReadMarker(r)
if err != nil {
return "", err
}
if (marker != AMF3_XMLDOC_MARKER) && (marker != AMF3_XMLSTRING_MARKER) {
return "", Error("decode assert marker failed: expected %v or %v, got %v", AMF3_XMLDOC_MARKER, AMF3_XMLSTRING_MARKER, marker)
}
}
var isRef bool
var refVal uint32
isRef, refVal, err = d.decodeReferenceInt(r)
if err != nil {
return "", Error("amf3 decode: unable to decode xml reference and length: %s", err)
}
if isRef {
var ok bool
buf := d.objectRefs[refVal]
result, ok = buf.(string)
if ok != true {
return "", Error("amf3 decode: cannot coerce object reference into xml string")
}
return
}
buf := make([]byte, refVal)
_, err = r.Read(buf)
if err != nil {
return "", Error("amf3 decode: unable to read xml string: %s", err)
}
result = string(buf)
if result != "" {
d.objectRefs = append(d.objectRefs, result)
}
return
}
// marker: 1 byte 0x0c
// format:
// - u29 reference int. if reference, no more data. if not reference,
// length value of bytes to read.
func (d *Decoder) DecodeAmf3ByteArray(r io.Reader, decodeMarker bool) (result []byte, err error) {
if err = AssertMarker(r, decodeMarker, AMF3_BYTEARRAY_MARKER); err != nil {
return
}
var isRef bool
var refVal uint32
isRef, refVal, err = d.decodeReferenceInt(r)
if err != nil {
return result, Error("amf3 decode: unable to decode byte array reference and length: %s", err)
}
if isRef {
var ok bool
result, ok = d.objectRefs[refVal].([]byte)
if ok != true {
return result, Error("amf3 decode: unable to convert object ref to bytes")
}
return
}
result = make([]byte, refVal)
_, err = r.Read(result)
if err != nil {
return result, Error("amf3 decode: unable to read bytearray: %s", err)
}
d.objectRefs = append(d.objectRefs, result)
return
}
func (d *Decoder) decodeU29(r io.Reader) (result uint32, err error) {
var b byte
for i := 0; i < 3; i++ {
b, err = ReadByte(r)
if err != nil {
return
}
result = (result << 7) + uint32(b&0x7F)
if (b & 0x80) == 0 {
return
}
}
b, err = ReadByte(r)
if err != nil {
return
}
result = ((result << 8) + uint32(b))
return
}
func (d *Decoder) decodeReferenceInt(r io.Reader) (isRef bool, refVal uint32, err error) {
u29, err := d.decodeU29(r)
if err != nil {
return false, 0, Error("amf3 decode: unable to decode reference int: %s", err)
}
isRef = u29&0x01 == 0
refVal = u29 >> 1
return
}
+127
View File
@@ -0,0 +1,127 @@
package amf
import (
"fmt"
"io"
"math"
)
// Abstract external boilerplate
func (d *Decoder) decodeAbstractMessage(r io.Reader) (result Object, err error) {
result = make(Object)
if err = d.decodeExternal(r, &result,
[]string{"body", "clientId", "destination", "headers", "messageId", "timeStamp", "timeToLive"},
[]string{"clientIdBytes", "messageIdBytes"}); err != nil {
return result, Error("unable to decode abstract external: %s", err)
}
return
}
// DSA
func (d *Decoder) decodeAsyncMessageExt(r io.Reader) (result Object, err error) {
return d.decodeAsyncMessage(r)
}
func (d *Decoder) decodeAsyncMessage(r io.Reader) (result Object, err error) {
result, err = d.decodeAbstractMessage(r)
if err != nil {
return result, Error("unable to decode abstract for async: %s", err)
}
if err = d.decodeExternal(r, &result, []string{"correlationId", "correlationIdBytes"}); err != nil {
return result, Error("unable to decode async external: %s", err)
}
return
}
// DSK
func (d *Decoder) decodeAcknowledgeMessageExt(r io.Reader) (result Object, err error) {
return d.decodeAcknowledgeMessage(r)
}
func (d *Decoder) decodeAcknowledgeMessage(r io.Reader) (result Object, err error) {
result, err = d.decodeAsyncMessage(r)
if err != nil {
return result, Error("unable to decode async for ack: %s", err)
}
if err = d.decodeExternal(r, &result); err != nil {
return result, Error("unable to decode ack external: %s", err)
}
return
}
// flex.messaging.io.ArrayCollection
func (d *Decoder) decodeArrayCollection(r io.Reader) (interface{}, error) {
result, err := d.DecodeAmf3(r)
if err != nil {
return result, Error("cannot decode child of array collection: %s", err)
}
return result, nil
}
func (d *Decoder) decodeExternal(r io.Reader, obj *Object, fieldSets ...[]string) (err error) {
var flagSet []uint8
var reservedPosition uint8
var fieldNames []string
flagSet, err = readFlags(r)
if err != nil {
return Error("unable to read flags: %s", err)
}
for i, flags := range flagSet {
if i < len(fieldSets) {
fieldNames = fieldSets[i]
} else {
fieldNames = []string{}
}
reservedPosition = uint8(len(fieldNames))
for p, field := range fieldNames {
flagBit := uint8(math.Exp2(float64(p)))
if (flags & flagBit) != 0 {
tmp, err := d.DecodeAmf3(r)
if err != nil {
return Error("unable to decode external field %s %d %d (%#v): %s", field, i, p, flagSet, err)
}
(*obj)[field] = tmp
}
}
if (flags >> reservedPosition) != 0 {
for j := reservedPosition; j < 6; j++ {
if ((flags >> j) & 0x01) != 0 {
field := fmt.Sprintf("extra_%d_%d", i, j)
tmp, err := d.DecodeAmf3(r)
if err != nil {
return Error("unable to decode post-external field %d %d (%#v): %s", i, j, flagSet, err)
}
(*obj)[field] = tmp
}
}
}
}
return
}
func readFlags(r io.Reader) (result []uint8, err error) {
for {
flag, err := ReadByte(r)
if err != nil {
return result, Error("unable to read flags: %s", err)
}
result = append(result, flag)
if (flag & 0x80) == 0 {
break
}
}
return
}
+220
View File
@@ -0,0 +1,220 @@
package amf
import (
"bytes"
"testing"
)
type u29TestCase struct {
value uint32
expect []byte
}
var u29TestCases = []u29TestCase{
{1, []byte{0x01}},
{2, []byte{0x02}},
{127, []byte{0x7F}},
{128, []byte{0x81, 0x00}},
{255, []byte{0x81, 0x7F}},
{256, []byte{0x82, 0x00}},
{0x3FFF, []byte{0xFF, 0x7F}},
{0x4000, []byte{0x81, 0x80, 0x00}},
{0x7FFF, []byte{0x81, 0xFF, 0x7F}},
{0x8000, []byte{0x82, 0x80, 0x00}},
{0x1FFFFF, []byte{0xFF, 0xFF, 0x7F}},
{0x200000, []byte{0x80, 0xC0, 0x80, 0x00}},
{0x3FFFFF, []byte{0x80, 0xFF, 0xFF, 0xFF}},
{0x400000, []byte{0x81, 0x80, 0x80, 0x00}},
{0x0FFFFFFF, []byte{0xBF, 0xFF, 0xFF, 0xFF}},
}
func TestDecodeAmf3Undefined(t *testing.T) {
buf := bytes.NewReader([]byte{0x00})
dec := new(Decoder)
got, err := dec.DecodeAmf3(buf)
if err != nil {
t.Errorf("%s", err)
}
if got != nil {
t.Errorf("expect nil got %v", got)
}
}
func TestDecodeAmf3Null(t *testing.T) {
buf := bytes.NewReader([]byte{0x01})
dec := new(Decoder)
got, err := dec.DecodeAmf3(buf)
if err != nil {
t.Errorf("%s", err)
}
if got != nil {
t.Errorf("expect nil got %v", got)
}
}
func TestDecodeAmf3False(t *testing.T) {
buf := bytes.NewReader([]byte{0x02})
expect := false
dec := new(Decoder)
got, err := dec.DecodeAmf3(buf)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
}
func TestDecodeAmf3True(t *testing.T) {
buf := bytes.NewReader([]byte{0x03})
expect := true
dec := new(Decoder)
got, err := dec.DecodeAmf3(buf)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
}
func TestDecodeU29(t *testing.T) {
dec := new(Decoder)
for _, tc := range u29TestCases {
buf := bytes.NewBuffer(tc.expect)
n, err := dec.decodeU29(buf)
if err != nil {
t.Errorf("DecodeAmf3Integer error: %s", err)
}
if n != tc.value {
t.Errorf("DecodeAmf3Integer expect n %x got %x", tc.value, n)
}
}
}
func TestDecodeAmf3Integer(t *testing.T) {
dec := new(Decoder)
buf := bytes.NewReader([]byte{0x04, 0xFF, 0xFF, 0x7F})
expect := int32(2097151)
got, err := dec.DecodeAmf3(buf)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
buf.Seek(0, 0)
got, err = dec.DecodeAmf3Integer(buf, true)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
buf.Seek(1, 0)
got, err = dec.DecodeAmf3Integer(buf, false)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
}
func TestDecodeAmf3Double(t *testing.T) {
buf := bytes.NewReader([]byte{0x05, 0x3f, 0xf3, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33})
expect := float64(1.2)
dec := new(Decoder)
got, err := dec.DecodeAmf3(buf)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
}
func TestDecodeAmf3String(t *testing.T) {
buf := bytes.NewReader([]byte{0x06, 0x07, 'f', 'o', 'o'})
expect := "foo"
dec := new(Decoder)
got, err := dec.DecodeAmf3(buf)
if err != nil {
t.Errorf("%s", err)
}
if expect != got {
t.Errorf("expect %v got %v", expect, got)
}
}
func TestDecodeAmf3Array(t *testing.T) {
buf := bytes.NewReader([]byte{0x09, 0x13, 0x01,
0x06, 0x03, '1',
0x06, 0x03, '2',
0x06, 0x03, '3',
0x06, 0x03, '4',
0x06, 0x03, '5',
0x06, 0x03, '6',
0x06, 0x03, '7',
0x06, 0x03, '8',
0x06, 0x03, '9',
})
dec := new(Decoder)
expect := []string{"1", "2", "3", "4", "5", "6", "7", "8", "9"}
got, err := dec.DecodeAmf3Array(buf, true)
if err != nil {
t.Errorf("err: %s", err)
}
for i, v := range expect {
if got[i] != v {
t.Error("expected array element %d to be %v, got %v", i, v, got[i])
}
}
}
func TestDecodeAmf3Object(t *testing.T) {
buf := bytes.NewReader([]byte{
0x0a, 0x23, 0x1f, 'o', 'r', 'g', '.', 'a',
'm', 'f', '.', 'A', 'S', 'C', 'l', 'a',
's', 's', 0x07, 'b', 'a', 'z', 0x07, 'f',
'o', 'o', 0x01, 0x06, 0x07, 'b', 'a', 'r',
})
dec := new(Decoder)
got, err := dec.DecodeAmf3(buf)
if err != nil {
t.Errorf("err: %s", err)
}
to, ok := got.(Object)
if ok != true {
t.Error("unable to cast object as typed object")
}
if to["foo"] != "bar" {
t.Error("expected foo to be bar, got: %+v", to["foo"])
}
if to["baz"] != nil {
t.Error("expected baz to be nil, got: %+v", to["baz"])
}
}
+308
View File
@@ -0,0 +1,308 @@
package amf
import (
"encoding/binary"
"io"
"reflect"
)
// amf0 polymorphic router
func (e *Encoder) EncodeAmf0(w io.Writer, val interface{}) (int, error) {
if val == nil {
return e.EncodeAmf0Null(w, true)
}
v := reflect.ValueOf(val)
if !v.IsValid() {
return e.EncodeAmf0Null(w, true)
}
switch v.Kind() {
case reflect.String:
str := v.String()
if len(str) <= AMF0_STRING_MAX {
return e.EncodeAmf0String(w, str, true)
} else {
return e.EncodeAmf0LongString(w, str, true)
}
case reflect.Bool:
return e.EncodeAmf0Boolean(w, v.Bool(), true)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return e.EncodeAmf0Number(w, float64(v.Int()), true)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return e.EncodeAmf0Number(w, float64(v.Uint()), true)
case reflect.Float32, reflect.Float64:
return e.EncodeAmf0Number(w, float64(v.Float()), true)
case reflect.Array, reflect.Slice:
length := v.Len()
arr := make(Array, length)
for i := 0; i < length; i++ {
arr[i] = v.Index(int(i)).Interface()
}
return e.EncodeAmf0StrictArray(w, arr, true)
case reflect.Map:
obj, ok := val.(Object)
if ok != true {
return 0, Error("encode amf0: unable to create object from map")
}
return e.EncodeAmf0Object(w, obj, true)
}
if _, ok := val.(TypedObject); ok {
return 0, Error("encode amf0: unsupported type typed object")
}
return 0, Error("encode amf0: unsupported type %s", v.Type())
}
// marker: 1 byte 0x00
// format: 8 byte big endian float64
func (e *Encoder) EncodeAmf0Number(w io.Writer, val float64, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF0_NUMBER_MARKER); err != nil {
return
}
n += 1
}
err = binary.Write(w, binary.BigEndian, &val)
if err != nil {
return
}
n += 8
return
}
// marker: 1 byte 0x01
// format: 1 byte, 0x00 = false, 0x01 = true
func (e *Encoder) EncodeAmf0Boolean(w io.Writer, val bool, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF0_BOOLEAN_MARKER); err != nil {
return
}
n += 1
}
var m int
buf := make([]byte, 1)
if val {
buf[0] = AMF0_BOOLEAN_TRUE
} else {
buf[0] = AMF0_BOOLEAN_FALSE
}
m, err = w.Write(buf)
if err != nil {
return
}
n += m
return
}
// marker: 1 byte 0x02
// format:
// - 2 byte big endian uint16 header to determine size
// - n (size) byte utf8 string
func (e *Encoder) EncodeAmf0String(w io.Writer, val string, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF0_STRING_MARKER); err != nil {
return
}
n += 1
}
var m int
length := uint16(len(val))
err = binary.Write(w, binary.BigEndian, length)
if err != nil {
return n, Error("encode amf0: unable to encode string length: %s", err)
}
n += 2
m, err = w.Write([]byte(val))
if err != nil {
return n, Error("encode amf0: unable to encode string value: %s", err)
}
n += m
return
}
// marker: 1 byte 0x03
// format:
// - loop encoded string followed by encoded value
// - terminated with empty string followed by 1 byte 0x09
func (e *Encoder) EncodeAmf0Object(w io.Writer, val Object, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF0_OBJECT_MARKER); err != nil {
return
}
n += 1
}
var m int
for k, v := range val {
m, err = e.EncodeAmf0String(w, k, false)
if err != nil {
return n, Error("encode amf0: unable to encode object key: %s", err)
}
n += m
m, err = e.EncodeAmf0(w, v)
if err != nil {
return n, Error("encode amf0: unable to encode object value: %s", err)
}
n += m
}
m, err = e.EncodeAmf0String(w, "", false)
if err != nil {
return n, Error("encode amf0: unable to encode object empty string: %s", err)
}
n += m
err = WriteMarker(w, AMF0_OBJECT_END_MARKER)
if err != nil {
return n, Error("encode amf0: unable to object end marker: %s", err)
}
n += 1
return
}
// marker: 1 byte 0x05
// no additional data
func (e *Encoder) EncodeAmf0Null(w io.Writer, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF0_NULL_MARKER); err != nil {
return
}
n += 1
}
return
}
// marker: 1 byte 0x06
// no additional data
func (e *Encoder) EncodeAmf0Undefined(w io.Writer, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF0_UNDEFINED_MARKER); err != nil {
return
}
n += 1
}
return
}
// marker: 1 byte 0x08
// format:
// - 4 byte big endian uint32 with length of associative array
// - normal object format:
// - loop encoded string followed by encoded value
// - terminated with empty string followed by 1 byte 0x09
func (e *Encoder) EncodeAmf0EcmaArray(w io.Writer, val Object, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF0_ECMA_ARRAY_MARKER); err != nil {
return
}
n += 1
}
var m int
length := uint32(len(val))
err = binary.Write(w, binary.BigEndian, length)
if err != nil {
return n, Error("encode amf0: unable to encode ecma array length: %s", err)
}
n += 4
m, err = e.EncodeAmf0Object(w, val, false)
if err != nil {
return n, Error("encode amf0: unable to encode ecma array object: %s", err)
}
n += m
return
}
// marker: 1 byte 0x0a
// format:
// - 4 byte big endian uint32 to determine length of associative array
// - n (length) encoded values
func (e *Encoder) EncodeAmf0StrictArray(w io.Writer, val Array, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF0_STRICT_ARRAY_MARKER); err != nil {
return
}
n += 1
}
var m int
length := uint32(len(val))
err = binary.Write(w, binary.BigEndian, length)
if err != nil {
return n, Error("encode amf0: unable to encode strict array length: %s", err)
}
n += 4
for _, v := range val {
m, err = e.EncodeAmf0(w, v)
if err != nil {
return n, Error("encode amf0: unable to encode strict array element: %s", err)
}
n += m
}
return
}
// marker: 1 byte 0x0c
// format:
// - 4 byte big endian uint32 header to determine size
// - n (size) byte utf8 string
func (e *Encoder) EncodeAmf0LongString(w io.Writer, val string, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF0_LONG_STRING_MARKER); err != nil {
return
}
n += 1
}
var m int
length := uint32(len(val))
err = binary.Write(w, binary.BigEndian, length)
if err != nil {
return n, Error("encode amf0: unable to encode long string length: %s", err)
}
n += 4
m, err = w.Write([]byte(val))
if err != nil {
return n, Error("encode amf0: unable to encode long string value: %s", err)
}
n += m
return
}
// marker: 1 byte 0x0d
// no additional data
func (e *Encoder) EncodeAmf0Unsupported(w io.Writer, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF0_UNSUPPORTED_MARKER); err != nil {
return
}
n += 1
}
return
}
// marker: 1 byte 0x11
func (e *Encoder) EncodeAmf0Amf3Marker(w io.Writer) error {
return WriteMarker(w, AMF0_ACMPLUS_OBJECT_MARKER)
}
+212
View File
@@ -0,0 +1,212 @@
package amf
import (
"bytes"
"encoding/binary"
"testing"
)
func TestEncodeAmf0Number(t *testing.T) {
buf := new(bytes.Buffer)
expect := []byte{0x00, 0x3f, 0xf3, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33}
enc := new(Encoder)
n, err := enc.EncodeAmf0(buf, float64(1.2))
if err != nil {
t.Errorf("%s", err)
}
if n != 9 {
t.Errorf("expected to write 9 bytes, actual %d", n)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes())
}
}
func TestEncodeAmf0BooleanTrue(t *testing.T) {
buf := new(bytes.Buffer)
expect := []byte{0x01, 0x01}
enc := new(Encoder)
n, err := enc.EncodeAmf0(buf, true)
if err != nil {
t.Errorf("%s", err)
}
if n != 2 {
t.Errorf("expected to write 2 bytes, actual %d", n)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes())
}
}
func TestEncodeAmf0BooleanFalse(t *testing.T) {
buf := new(bytes.Buffer)
expect := []byte{0x01, 0x00}
enc := new(Encoder)
n, err := enc.EncodeAmf0(buf, false)
if err != nil {
t.Errorf("%s", err)
}
if n != 2 {
t.Errorf("expected to write 2 bytes, actual %d", n)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes())
}
}
func TestEncodeAmf0String(t *testing.T) {
buf := new(bytes.Buffer)
expect := []byte{0x02, 0x00, 0x03, 0x66, 0x6f, 0x6f}
enc := new(Encoder)
n, err := enc.EncodeAmf0(buf, "foo")
if err != nil {
t.Errorf("%s", err)
}
if n != 6 {
t.Errorf("expected to write 6 bytes, actual %d", n)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes())
}
}
func TestEncodeAmf0Object(t *testing.T) {
buf := new(bytes.Buffer)
expect := []byte{0x03, 0x00, 0x03, 0x66, 0x6f, 0x6f, 0x02, 0x00, 0x03, 0x62, 0x61, 0x72, 0x00, 0x00, 0x09}
enc := new(Encoder)
obj := make(Object)
obj["foo"] = "bar"
n, err := enc.EncodeAmf0(buf, obj)
if err != nil {
t.Errorf("%s", err)
}
if n != 15 {
t.Errorf("expected to write 15 bytes, actual %d", n)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes())
}
}
func TestEncodeAmf0EcmaArray(t *testing.T) {
buf := new(bytes.Buffer)
expect := []byte{0x08, 0x00, 0x00, 0x00, 0x01, 0x00, 0x03, 0x66, 0x6f, 0x6f, 0x02, 0x00, 0x03, 0x62, 0x61, 0x72, 0x00, 0x00, 0x09}
enc := new(Encoder)
obj := make(Object)
obj["foo"] = "bar"
_, err := enc.EncodeAmf0EcmaArray(buf, obj, true)
if err != nil {
t.Errorf("%s", err)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes())
}
}
func TestEncodeAmf0StrictArray(t *testing.T) {
buf := new(bytes.Buffer)
expect := []byte{0x0a, 0x00, 0x00, 0x00, 0x03, 0x00, 0x40, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x03, 0x66, 0x6f, 0x6f, 0x05}
enc := new(Encoder)
arr := make(Array, 3)
arr[0] = float64(5)
arr[1] = "foo"
arr[2] = nil
_, err := enc.EncodeAmf0StrictArray(buf, arr, true)
if err != nil {
t.Errorf("%s", err)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes())
}
}
func TestEncodeAmf0Null(t *testing.T) {
buf := new(bytes.Buffer)
expect := []byte{0x05}
enc := new(Encoder)
n, err := enc.EncodeAmf0(buf, nil)
if err != nil {
t.Errorf("%s", err)
}
if n != 1 {
t.Errorf("expected to write 1 byte, actual %d", n)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes())
}
}
func TestEncodeAmf0LongString(t *testing.T) {
buf := new(bytes.Buffer)
testBytes := []byte("12345678")
tbuf := new(bytes.Buffer)
for i := 0; i < 65536; i++ {
tbuf.Write(testBytes)
}
enc := new(Encoder)
_, err := enc.EncodeAmf0(buf, string(tbuf.Bytes()))
if err != nil {
t.Errorf("%s", err)
}
mbuf := make([]byte, 1)
_, err = buf.Read(mbuf)
if err != nil {
t.Errorf("error reading header")
}
if mbuf[0] != 0x0c {
t.Errorf("marker mismatch")
}
var length uint32
err = binary.Read(buf, binary.BigEndian, &length)
if err != nil {
t.Errorf("error reading buffer")
}
if length != (65536 * 8) {
t.Errorf("expected length to be %d, got %d", (65536 * 8), length)
}
tmpBuf := make([]byte, 8)
counter := 0
for buf.Len() > 0 {
n, err := buf.Read(tmpBuf)
if err != nil {
t.Fatalf("test long string result check, read data(%d) error: %s, n: %d", counter, err, n)
}
if n != 8 {
t.Fatalf("test long string result check, read data(%d) n: %d", counter, n)
}
if !bytes.Equal(testBytes, tmpBuf) {
t.Fatalf("test long string result check, read data % x", tmpBuf)
}
counter++
}
}
+431
View File
@@ -0,0 +1,431 @@
package amf
import (
"encoding/binary"
"io"
"reflect"
"sort"
"time"
)
// amf3 polymorphic router
func (e *Encoder) EncodeAmf3(w io.Writer, val interface{}) (int, error) {
if val == nil {
return e.EncodeAmf3Null(w, true)
}
v := reflect.ValueOf(val)
if !v.IsValid() {
return e.EncodeAmf3Null(w, true)
}
switch v.Kind() {
case reflect.String:
return e.EncodeAmf3String(w, v.String(), true)
case reflect.Bool:
if v.Bool() {
return e.EncodeAmf3True(w, true)
} else {
return e.EncodeAmf3False(w, true)
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
n := v.Int()
if n >= 0 && n <= AMF3_INTEGER_MAX {
return e.EncodeAmf3Integer(w, uint32(n), true)
} else {
return e.EncodeAmf3Double(w, float64(n), true)
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
n := v.Uint()
if n <= AMF3_INTEGER_MAX {
return e.EncodeAmf3Integer(w, uint32(n), true)
} else {
return e.EncodeAmf3Double(w, float64(n), true)
}
case reflect.Int64:
return e.EncodeAmf3Double(w, float64(v.Int()), true)
case reflect.Uint64:
return e.EncodeAmf3Double(w, float64(v.Uint()), true)
case reflect.Float32, reflect.Float64:
return e.EncodeAmf3Double(w, float64(v.Float()), true)
case reflect.Array, reflect.Slice:
length := v.Len()
arr := make(Array, length)
for i := 0; i < length; i++ {
arr[i] = v.Index(int(i)).Interface()
}
return e.EncodeAmf3Array(w, arr, true)
case reflect.Map:
obj, ok := val.(Object)
if ok != true {
return 0, Error("encode amf3: unable to create object from map")
}
to := *new(TypedObject)
to.Object = obj
return e.EncodeAmf3Object(w, to, true)
}
if tm, ok := val.(time.Time); ok {
return e.EncodeAmf3Date(w, tm, true)
}
if to, ok := val.(TypedObject); ok {
return e.EncodeAmf3Object(w, to, true)
}
return 0, Error("encode amf3: unsupported type %s", v.Type())
}
// marker: 1 byte 0x00
// no additional data
func (e *Encoder) EncodeAmf3Undefined(w io.Writer, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF3_UNDEFINED_MARKER); err != nil {
return
}
n += 1
}
return
}
// marker: 1 byte 0x01
// no additional data
func (e *Encoder) EncodeAmf3Null(w io.Writer, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF3_NULL_MARKER); err != nil {
return
}
n += 1
}
return
}
// marker: 1 byte 0x02
// no additional data
func (e *Encoder) EncodeAmf3False(w io.Writer, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF3_FALSE_MARKER); err != nil {
return
}
n += 1
}
return
}
// marker: 1 byte 0x03
// no additional data
func (e *Encoder) EncodeAmf3True(w io.Writer, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF3_TRUE_MARKER); err != nil {
return
}
n += 1
}
return
}
// marker: 1 byte 0x04
func (e *Encoder) EncodeAmf3Integer(w io.Writer, val uint32, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF3_INTEGER_MARKER); err != nil {
return
}
n += 1
}
var m int
m, err = e.encodeAmf3Uint29(w, val)
if err != nil {
return
}
n += m
return
}
// marker: 1 byte 0x05
func (e *Encoder) EncodeAmf3Double(w io.Writer, val float64, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF3_DOUBLE_MARKER); err != nil {
return
}
n += 1
}
err = binary.Write(w, binary.BigEndian, &val)
if err != nil {
return
}
n += 8
return
}
// marker: 1 byte 0x06
// format:
// - u29 reference int. if reference, no more data. if not reference,
// length value of bytes to read to complete string.
func (e *Encoder) EncodeAmf3String(w io.Writer, val string, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF3_STRING_MARKER); err != nil {
return
}
n += 1
}
var m int
m, err = e.encodeAmf3Utf8(w, val)
if err != nil {
return
}
n += m
return
}
// marker: 1 byte 0x08
// format:
// - u29 reference int, if reference, no more data
// - timestamp double
func (e *Encoder) EncodeAmf3Date(w io.Writer, val time.Time, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF3_DATE_MARKER); err != nil {
return
}
n += 1
}
if err = WriteMarker(w, 0x01); err != nil {
return n, Error("amf3 encode: cannot encode u29 for array: %s", err)
}
n += 1
u64 := float64(val.Unix()) * 1000.0
err = binary.Write(w, binary.BigEndian, &u64)
if err != nil {
return n, Error("amf3 encode: unable to write date double: %s", err)
}
n += 8
return
}
// marker: 1 byte 0x09
// format:
// - u29 reference int. if reference, no more data.
// - string representing associative array if present
// - n values (length of u29)
func (e *Encoder) EncodeAmf3Array(w io.Writer, val Array, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF3_ARRAY_MARKER); err != nil {
return
}
n += 1
}
var m int
length := uint32(len(val))
u29 := uint32(length<<1) | 0x01
m, err = e.encodeAmf3Uint29(w, u29)
if err != nil {
return n, Error("amf3 encode: cannot encode u29 for array: %s", err)
}
n += m
m, err = e.encodeAmf3Utf8(w, "")
if err != nil {
return n, Error("amf3 encode: cannot encode empty string for array: %s", err)
}
n += m
for _, v := range val {
m, err := e.EncodeAmf3(w, v)
if err != nil {
return n, Error("amf3 encode: cannot encode array element: %s", err)
}
n += m
}
return
}
// marker: 1 byte 0x0a
// format: ugh
func (e *Encoder) EncodeAmf3Object(w io.Writer, val TypedObject, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF3_OBJECT_MARKER); err != nil {
return
}
n += 1
}
m := 0
trait := *NewTrait()
trait.Type = val.Type
trait.Dynamic = false
trait.Externalizable = false
for k, _ := range val.Object {
trait.Properties = append(trait.Properties, k)
}
sort.Strings(trait.Properties)
var u29 uint32 = 0x03
if trait.Dynamic {
u29 |= 0x02 << 2
}
if trait.Externalizable {
u29 |= 0x01 << 2
}
u29 |= uint32(len(trait.Properties)) << 4
m, err = e.encodeAmf3Uint29(w, u29)
if err != nil {
return n, Error("amf3 encode: cannot encode trait header for object: %s", err)
}
n += m
m, err = e.encodeAmf3Utf8(w, trait.Type)
if err != nil {
return n, Error("amf3 encode: cannot encode trait type for object: %s", err)
}
n += m
for _, prop := range trait.Properties {
m, err = e.encodeAmf3Utf8(w, prop)
if err != nil {
return n, Error("amf3 encode: cannot encode trait property for object: %s", err)
}
n += m
}
if trait.Externalizable {
return n, Error("amf3 encode: cannot encode externalizable object")
}
for _, prop := range trait.Properties {
m, err = e.EncodeAmf3(w, val.Object[prop])
if err != nil {
return n, Error("amf3 encode: cannot encode sealed object value: %s", err)
}
n += m
}
if trait.Dynamic {
for k, v := range val.Object {
var foundProp bool = false
for _, prop := range trait.Properties {
if prop == k {
foundProp = true
break
}
}
if foundProp != true {
m, err = e.encodeAmf3Utf8(w, k)
if err != nil {
return n, Error("amf3 encode: cannot encode dynamic object property key: %s", err)
}
n += m
m, err = e.EncodeAmf3(w, v)
if err != nil {
return n, Error("amf3 encode: cannot encode dynamic object value: %s", err)
}
n += m
}
m, err = e.encodeAmf3Utf8(w, "")
if err != nil {
return n, Error("amf3 encode: cannot encode dynamic object ending marker string: %s", err)
}
n += m
}
}
return
}
// marker: 1 byte 0x0c
// format:
// - u29 reference int. if reference, no more data. if not reference,
// length value of bytes to read .
func (e *Encoder) EncodeAmf3ByteArray(w io.Writer, val []byte, encodeMarker bool) (n int, err error) {
if encodeMarker {
if err = WriteMarker(w, AMF3_BYTEARRAY_MARKER); err != nil {
return
}
n += 1
}
var m int
length := uint32(len(val))
u29 := (length << 1) | 1
m, err = e.encodeAmf3Uint29(w, u29)
if err != nil {
return n, Error("amf3 encode: cannot encode u29 for bytearray: %s", err)
}
n += m
m, err = w.Write(val)
if err != nil {
return n, Error("encode amf3: unable to encode bytearray value: %s", err)
}
n += m
return
}
func (e *Encoder) encodeAmf3Utf8(w io.Writer, val string) (n int, err error) {
length := uint32(len(val))
u29 := uint32(length<<1) | 0x01
var m int
m, err = e.encodeAmf3Uint29(w, u29)
if err != nil {
return n, Error("amf3 encode: cannot encode u29 for string: %s", err)
}
n += m
m, err = w.Write([]byte(val))
if err != nil {
return n, Error("encode amf3: unable to encode string value: %s", err)
}
n += m
return
}
func (e *Encoder) encodeAmf3Uint29(w io.Writer, val uint32) (n int, err error) {
if val <= 0x0000007F {
err = WriteByte(w, byte(val))
if err == nil {
n += 1
}
} else if val <= 0x00003FFF {
n, err = w.Write([]byte{byte(val>>7 | 0x80), byte(val & 0x7F)})
} else if val <= 0x001FFFFF {
n, err = w.Write([]byte{byte(val>>14 | 0x80), byte(val>>7&0x7F | 0x80), byte(val & 0x7F)})
} else if val <= 0x1FFFFFFF {
n, err = w.Write([]byte{byte(val>>22 | 0x80), byte(val>>15&0x7F | 0x80), byte(val>>8&0x7F | 0x80), byte(val)})
} else {
return n, Error("amf3 encode: cannot encode u29 with value %d (out of range)", val)
}
return
}
+199
View File
@@ -0,0 +1,199 @@
package amf
import (
"bytes"
"testing"
)
func TestEncodeAmf3EmptyString(t *testing.T) {
enc := new(Encoder)
buf := new(bytes.Buffer)
expect := []byte{0x01}
_, err := enc.EncodeAmf3String(buf, "", false)
if err != nil {
t.Errorf("%s", err)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes())
}
}
func TestEncodeAmf3Undefined(t *testing.T) {
enc := new(Encoder)
buf := new(bytes.Buffer)
expect := []byte{0x00}
_, err := enc.EncodeAmf3Undefined(buf, true)
if err != nil {
t.Errorf("%s", err)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes())
}
}
func TestEncodeAmf3Null(t *testing.T) {
enc := new(Encoder)
buf := new(bytes.Buffer)
expect := []byte{0x01}
_, err := enc.EncodeAmf3(buf, nil)
if err != nil {
t.Errorf("%s", err)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes())
}
}
func TestEncodeAmf3False(t *testing.T) {
enc := new(Encoder)
buf := new(bytes.Buffer)
expect := []byte{0x02}
_, err := enc.EncodeAmf3(buf, false)
if err != nil {
t.Errorf("%s", err)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes())
}
}
func TestEncodeAmf3True(t *testing.T) {
enc := new(Encoder)
buf := new(bytes.Buffer)
expect := []byte{0x03}
_, err := enc.EncodeAmf3(buf, true)
if err != nil {
t.Errorf("%s", err)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes())
}
}
func TestEncodeAmf3Integer(t *testing.T) {
enc := new(Encoder)
for _, tc := range u29TestCases {
buf := new(bytes.Buffer)
_, err := enc.EncodeAmf3Integer(buf, tc.value, false)
if err != nil {
t.Errorf("EncodeAmf3Integer error: %s", err)
}
got := buf.Bytes()
if !bytes.Equal(tc.expect, got) {
t.Errorf("EncodeAmf3Integer expect n %x got %x", tc.value, got)
}
}
buf := new(bytes.Buffer)
expect := []byte{0x04, 0x80, 0xFF, 0xFF, 0xFF}
n, err := enc.EncodeAmf3(buf, uint32(4194303))
if err != nil {
t.Errorf("%s", err)
}
if n != 5 {
t.Errorf("expected to write 5 bytes, actual %d", n)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes())
}
}
func TestEncodeAmf3Double(t *testing.T) {
enc := new(Encoder)
buf := new(bytes.Buffer)
expect := []byte{0x05, 0x3f, 0xf3, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33}
_, err := enc.EncodeAmf3(buf, float64(1.2))
if err != nil {
t.Errorf("%s", err)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes())
}
}
func TestEncodeAmf3String(t *testing.T) {
enc := new(Encoder)
buf := new(bytes.Buffer)
expect := []byte{0x06, 0x07, 'f', 'o', 'o'}
_, err := enc.EncodeAmf3(buf, "foo")
if err != nil {
t.Errorf("%s", err)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes())
}
}
func TestEncodeAmf3Array(t *testing.T) {
enc := new(Encoder)
buf := new(bytes.Buffer)
expect := []byte{0x09, 0x13, 0x01,
0x06, 0x03, '1',
0x06, 0x03, '2',
0x06, 0x03, '3',
0x06, 0x03, '4',
0x06, 0x03, '5',
0x06, 0x03, '6',
0x06, 0x03, '7',
0x06, 0x03, '8',
0x06, 0x03, '9',
}
arr := []string{"1", "2", "3", "4", "5", "6", "7", "8", "9"}
_, err := enc.EncodeAmf3(buf, arr)
if err != nil {
t.Errorf("err: %s", err)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer: %+v, got: %+v", expect, buf.Bytes())
}
}
func TestEncodeAmf3Object(t *testing.T) {
enc := new(Encoder)
buf := new(bytes.Buffer)
expect := []byte{
0x0a, 0x23, 0x1f, 'o', 'r', 'g', '.', 'a',
'm', 'f', '.', 'A', 'S', 'C', 'l', 'a',
's', 's', 0x07, 'b', 'a', 'z', 0x07, 'f',
'o', 'o', 0x01, 0x06, 0x07, 'b', 'a', 'r',
}
to := *NewTypedObject()
to.Type = "org.amf.ASClass"
to.Object["foo"] = "bar"
to.Object["baz"] = nil
_, err := enc.EncodeAmf3(buf, to)
if err != nil {
t.Errorf("err: %s", err)
}
if bytes.Compare(buf.Bytes(), expect) != 0 {
t.Errorf("expected buffer:\n%#v\ngot:\n%#v", expect, buf.Bytes())
}
}
+70
View File
@@ -0,0 +1,70 @@
package amf
import (
"bytes"
"fmt"
"log"
)
const (
ADD = 0x0
DEL = 0x3
)
const (
SetDataFrame string = "@setDataFrame"
OnMetaData string = "onMetaData"
)
var setFrameFrame []byte
func init() {
b := bytes.NewBuffer(nil)
encoder := &Encoder{}
if _, err := encoder.Encode(b, SetDataFrame, AMF0); err != nil {
log.Fatal(err)
}
setFrameFrame = b.Bytes()
}
func MetaDataReform(p []byte, flag uint8) ([]byte, error) {
r := bytes.NewReader(p)
decoder := &Decoder{}
switch flag {
case ADD:
v, err := decoder.Decode(r, AMF0)
if err != nil {
return nil, err
}
switch v.(type) {
case string:
vv := v.(string)
if vv != SetDataFrame {
tmplen := len(setFrameFrame)
b := make([]byte, tmplen+len(p))
copy(b, setFrameFrame)
copy(b[tmplen:], p)
p = b
}
default:
return nil, fmt.Errorf("setFrameFrame error")
}
case DEL:
v, err := decoder.Decode(r, AMF0)
if err != nil {
return nil, err
}
switch v.(type) {
case string:
vv := v.(string)
if vv == SetDataFrame {
p = p[len(setFrameFrame):]
}
default:
return nil, fmt.Errorf("metadata error")
}
default:
return nil, fmt.Errorf("invalid flag:%d", flag)
}
return p, nil
}
+92
View File
@@ -0,0 +1,92 @@
package amf
import (
"encoding/json"
"errors"
"fmt"
"io"
)
func DumpBytes(label string, buf []byte, size int) {
fmt.Printf("Dumping %s (%d bytes):\n", label, size)
for i := 0; i < size; i++ {
fmt.Printf("0x%02x ", buf[i])
}
fmt.Printf("\n")
}
func Dump(label string, val interface{}) error {
json, err := json.MarshalIndent(val, "", " ")
if err != nil {
return Error("Error dumping %s: %s", label, err)
}
fmt.Printf("Dumping %s:\n%s\n", label, json)
return nil
}
func Error(f string, v ...interface{}) error {
return errors.New(fmt.Sprintf(f, v...))
}
func WriteByte(w io.Writer, b byte) (err error) {
bytes := make([]byte, 1)
bytes[0] = b
_, err = WriteBytes(w, bytes)
return
}
func WriteBytes(w io.Writer, bytes []byte) (int, error) {
return w.Write(bytes)
}
func ReadByte(r io.Reader) (byte, error) {
bytes, err := ReadBytes(r, 1)
if err != nil {
return 0x00, err
}
return bytes[0], nil
}
func ReadBytes(r io.Reader, n int) ([]byte, error) {
bytes := make([]byte, n)
m, err := r.Read(bytes)
if err != nil {
return bytes, err
}
if m != n {
return bytes, fmt.Errorf("decode read bytes failed: expected %d got %d", m, n)
}
return bytes, nil
}
func WriteMarker(w io.Writer, m byte) error {
return WriteByte(w, m)
}
func ReadMarker(r io.Reader) (byte, error) {
return ReadByte(r)
}
func AssertMarker(r io.Reader, checkMarker bool, m byte) error {
if checkMarker == false {
return nil
}
marker, err := ReadMarker(r)
if err != nil {
return err
}
if marker != m {
return Error("decode assert marker failed: expected %v got %v", m, marker)
}
return nil
}
+92
View File
@@ -0,0 +1,92 @@
package cache
import (
"errors"
"flag"
"fmt"
"github.com/H0RlZ0N/gortmppush/av"
)
var (
gopNum = flag.Int("gopNum", 1, "gop num")
)
// Cache ...
type Cache struct {
gop *GopCache
videoSeq *av.Packet
audioSeq *av.Packet
metadata *av.Packet
}
// NewCache ...
func NewCache() *Cache {
return &Cache{
gop: NewGopCache(*gopNum),
videoSeq: nil,
audioSeq: nil,
metadata: nil,
}
}
func (cache *Cache) Write(p *av.Packet) {
switch p.PacketType {
case av.PacketTypeAudio:
// 目前只处理aac的sequence header,如果后续要支持更多的格式
// 可在此添加
if p.AHeader.SoundFormat == av.SOUND_AAC && p.AHeader.AACPacketType == av.AAC_SEQHDR {
cache.audioSeq = p
return
}
case av.PacketTypeVideo:
// 这里目前只处理h264的sequence和gop缓存
if p.VHeader.CodecID == av.VIDEO_H264 {
if p.VHeader.FrameType == av.FRAME_KEY {
if p.VHeader.AVCPacketType == av.AVC_SEQHDR {
cache.videoSeq = p
} else {
cache.gop.Write(p, true)
}
}
cache.gop.Write(p, false)
}
case av.PacketTypeMetadata:
cache.metadata = p
}
}
// Send ...
func (cache *Cache) Send(inputChan chan<- *av.Packet) error {
cachePkts := make([]*av.Packet, 3)
cachePkts = cachePkts[:0]
if cache.metadata != nil {
cachePkts = append(cachePkts, cache.metadata)
}
if cache.videoSeq != nil {
cachePkts = append(cachePkts, cache.videoSeq)
}
if cache.audioSeq != nil {
cachePkts = append(cachePkts, cache.audioSeq)
}
// 发送sequence header
for _, pkt := range cachePkts {
select {
case inputChan <- pkt:
fmt.Println("Input pkt....")
default:
return errors.New("send sequence failed")
}
}
// 发送视频帧
for _, pkt := range cache.gop.gops {
select {
case inputChan <- pkt:
fmt.Println("Input pkt....")
default:
}
}
return nil
}
+33
View File
@@ -0,0 +1,33 @@
package cache
import (
"github.com/H0RlZ0N/gortmppush/av"
)
// GopCache ...
type GopCache struct {
maxNumber int
count int
gops []*av.Packet
}
// NewGopCache ...
func NewGopCache(maxNumber int) *GopCache {
return &GopCache{
count: 0,
gops: make([]*av.Packet, 0),
}
}
func (gc *GopCache) Write(p *av.Packet, bKeyFrame bool) {
if bKeyFrame {
gc.gops = gc.gops[:0]
gc.count = 0
}
// todo 是否需要拷贝
if gc.count < gc.maxNumber {
gc.gops = append(gc.gops, p)
gc.count++
}
}
+50
View File
@@ -0,0 +1,50 @@
package cache
// import (
// "bytes"
// "log"
// "github.com/H0RlZ0N/gortmppush/av"
// "github.com/H0RlZ0N/gortmppush/protocol"
// "github.com/H0RlZ0N/gortmppush/protocol/amf"
// )
// const (
// SetDataFrame string = "@setDataFrame"
// OnMetaData string = "onMetaData"
// )
// var setFrameFrame []byte
// func init() {
// b := bytes.NewBuffer(nil)
// encoder := &amf.Encoder{}
// if _, err := encoder.Encode(b, SetDataFrame, amf.AMF0); err != nil {
// log.Fatal(err)
// }
// setFrameFrame = b.Bytes()
// }
// // SpecialCache ...
// type SpecialCache struct {
// full bool
// p *av.Packet
// }
// // NewSpecialCache ...
// func NewSpecialCache() *SpecialCache {
// return &SpecialCache{}
// }
// func (specialCache *SpecialCache) Write(p *av.Packet) {
// specialCache.p = p
// specialCache.full = true
// }
// // Send ...
// func (specialCache *SpecialCache) Send(w protocol.WriteCloser) error {
// if !specialCache.full {
// return nil
// }
// return w.Write(specialCache.p)
// }
+424
View File
@@ -0,0 +1,424 @@
package core
import (
"encoding/binary"
"fmt"
"github.com/H0RlZ0N/gortmppush/av"
"github.com/H0RlZ0N/gortmppush/utils"
)
/*
+--------------+----------------+--------------------+--------------+
| Basic Header | Message Header | Extended Timestamp | Chunk Data |
+--------------+----------------+--------------------+--------------+
| |
|<------------------- Chunk Header ----------------->|
Chunk Format
Basic Header(1-3字节)这个字段包含块流IDcsid和块类型fmtfmt取值0-3定义4中不同的块消息类型
rtmp协议支持用户自定义[3,65599]之间的CSID012由协议保留表示特殊信息
0-代表basic header总共要占用2个字节csid在[64,319]之间
1-代表占用3个字节csid在6465599之间
2-代表chunk是控制信息和一些命令信息
如果第一个字节的csid取值大于2则说明这个就是csid
0 1 2 3 4 5 6 7
+-+-+-+-+-+-+-+-+
|fmt| cs id |
+-+-+-+-+-+-+-+-+
0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|fmt| 0 | cs id - 64 |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|fmt| 1 | cs id - 64 |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Message Header(0,3,7,11字节)这个字段包含被发送的消息信息无论是全部还是部分字段长度由块头中的块类型fmt来决定
类型0--有11个字节组成其他三种能表示的数据它都能表示但在chunk stream的开始的第一个chunk和头信息中的时间戳后腿即值与上
一个chunk相比减小通常在回退播放的时候会出现这种情况的时候必须采用这种格式
0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| timestamp |message length |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| message length (coutinue) |message type id| msg stream id |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| msg stream id |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
timestamp时间戳占用3个字节因此它最多能表示到16777215=0xFFFFFF=2^24-1当它
的值超过这个最大值时这三个字节都置为1这样实际的timestamp会转存到 Extended
Timestamp 字段中接收端在判断timestamp字段24个位都为1时就会去Extended Timestamp
中解析实际的时间戳
message length消息数据长度占用3个字节表示实际发送的消息的数据如音频帧视频
帧等数据的长度单位是字节注意这里是Message的长度也就是chunk属于的Message的总长
而不是chunk本身data的长度
message type id(消息的类型id)1个字节表示实际发送的数据的类型主要分为一下几类
* 协议控制消息
SetChunkSize(type id=1):设置chunk中Data字段所能承载的最大字节数默认是128Bytes通信过程中可以通过发送该消息来设置chunk size的大小(不小于128B)
而且通信的双方各自维护一个chunksize两端的chunksize是独立的比如当A想向B发送一个200B的message但默认的chunksize是128B因此就要将消息拆分为
Data分别为128B和72B的两个chunk发送如果此时先发送一个设置chunksize为256B的消息再发送Data为200B的chunk本地不再划分messageB接收到的setchunksize
的协议控制消息会调整接收的chunk的Data的大小
Abort Message(type id=2)当一个Message被切分为多个chunk接收端只接收到部分chunk是发送该控制消息表示发送端不在传输痛Message的chunk接收端接收到这个
消息后要丢弃这些不完整的chunkData数据中只需要一个CSID表示丢弃该CSID的所有已接收到的chunk
Acknowledgement(type id=3): 当接收到对端的消息大小等于窗口大小(window size)时接收端要回馈一个ack给发送端告知对方可以继续发送数据窗口大小就是指接收到
接收端返回的ack前最多可以发送的字节数量返回的ack中会带有从发送上衣额ack后接收到的字节数
Window Acknowledgement Size(type id=5): 发送端在接收到接收端返回的连个ack间最多可以发送的字节数
Set Peer Bandwidth(type id=6): 限制对端的输出带宽接收端接收到该消息后会通过设置消息中的Window ACK Size来限制已发送但未接收到反馈的消息的大小来限制
发送端的发送带宽如果消息中的Window Ack Size与上一次发送给发送端的size不通的话要回馈一个Window Acknowledgement Size的控制消息
Hard(Limit Type=0)接收端应该将Window Ack Size设置为消息中的值
Soft(Limit Type=1)接收端可以将Window Ack Size设置为消息中的值也可以保存原来的值(前提是原来的size小于该控制消息中的Window Ack Size)
Dynamic(Limit Type=2)如果上次的Set Peer BandWidth消息中的Limit Type为0本次也按Hard处理否则忽略本消息不去设置Window Ack Size
* 数据消息
8--音频数据 9--视频数据 18--Metadata 包括音视频编码视频宽高等信息
* 命令消息
2017, 此类消息主要有NetConnection和NetStream两类两个类分别有多个函数
消息的调用可理解为远程函数调用
9代表视频数据
message stream id(消息的流id)4个字节表示该chunk所在的流的ID和Basic Header
的CSID一样它采用小端存储方式
类型1-由7个字节组成省去来表示message stream id的4个字节表示此chunk和上一次发的chunk所在的流相同如果在发送端值和对端
有一个流连接的时候可以尽量去采用这种格式
0 1 2 3
0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| timestamp delta |message length |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| message length (coutinue) |message type id|
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
timestamp delta3 bytes这里和type=0时不同存储的是和上一个chunk的时间差类似
上面提到的timestamp当它的值超过3个字节所能表示的最大值时三个字节都置为1实际
的时间戳差值就会转存到Extended Timestamp字段中接收端在判断timestamp delta字段24
个bit都为1时就会去Extended Timestamp 中解析实际的与上次时间戳的差值
其他字段与上面的解释相同.
类型2-type 2 时占用 3 个字节相对于 type = 1 格式又省去了表示消息长度的3个字节和表示消息类型的1个字节表示此 chunk
和上一次发送的 chunk 所在的流消息的长度和消息的类型都相同余下的这三个字节表示 timestamp delta使用同type=1
0 1 2
0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| timestamp delta |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
类型3-为0字节表示这个chunk的Message Header和上一个是完全相同的当它跟在type=0的chunk后面时表示和前一
chunk 的时间戳都是相同什么时候连时间戳都是相同呢就是一个 Message 拆分成多个 chunk这个 chunk 和上
一个 chunk 同属于一个 Message而当它跟在 type = 1 type = 2 的chunk后面时的chunk后面时表示和前一个 chunk
的时间戳的差是相同的比如第一个 chunk type = 0timestamp = 100第二个 chunk type = 2
timestamp delta = 20表示时间戳为 100 + 20 = 120第三个 chunk type = 3表示 timestamp delta = 20,
时间戳为 120 + 20 = 140
Extended Timestamp04字节这个字段是否存在取决于块消息头中编码的时间戳
chunk 中会有时间戳 timestamp 和时间戳差 timestamp delta并且它们不会同时存在只有这两者之一大于3字节能表示的
最大数值 0xFFFFFF 16777215 才会用这个字段来表示真正的时间戳否则这个字段为 0扩展时间戳占 4 个字节
能表示的最大数值就是 0xFFFFFFFF 4294967295当扩展时间戳启用时timestamp字段或者timestamp delta要全置为1
而不是减去时间戳或者时间戳差的值
Chunk Data可变大小当前块的有效数据上限为配置的最大块大小
*/
//ChunkStream 表示一个完整的message
type ChunkStream struct {
Format uint32 //2bit 代表chunk message type
CSID uint32 //chunk stream id
Timestamp uint32 //时间戳
Length uint32
TypeID uint32
StreamID uint32
Data []byte
Pts uint32
Dts uint32
timeDelta uint32 //时间戳扩展
exted bool
index uint32
remain uint32
complete bool
tmpFromat uint32
}
func (cs *ChunkStream) isComplete() bool {
return cs.complete
}
func (cs *ChunkStream) alloc() {
cs.complete = false
cs.index = 0
cs.remain = cs.Length
if len(cs.Data) < int(cs.Length) {
cs.Data = make([]byte, cs.Length)
}
}
func (cs *ChunkStream) writeHeader(w *ReadWriter) error {
//Chunk Basic Header
h := cs.Format << 6
switch {
case cs.CSID < 64:
h |= cs.CSID
w.WriteUintBE(h, 1)
case cs.CSID-64 < 256:
h |= 0
w.WriteUintBE(h, 1)
w.WriteUintLE(cs.CSID-64, 1)
case cs.CSID-64 < 65536:
h |= 1
w.WriteUintBE(h, 1)
w.WriteUintLE(cs.CSID-64, 2)
}
//Chunk Message Header
ts := cs.Timestamp
if cs.Format == 3 {
goto END
}
if cs.Timestamp > 0xffffff {
ts = 0xffffff
}
w.WriteUintBE(ts, 3)
if cs.Format == 2 {
goto END
}
if cs.Length > 0xffffff {
return fmt.Errorf("length=%d", cs.Length)
}
w.WriteUintBE(cs.Length, 3)
w.WriteUintBE(cs.TypeID, 1)
if cs.Format == 1 {
goto END
}
w.WriteUintLE(cs.StreamID, 4)
END:
//Extended Timestamp
if ts >= 0xffffff {
w.WriteUintBE(cs.Timestamp, 4)
}
return w.WriteError()
}
func (cs *ChunkStream) writeChunk(w *ReadWriter, chunkSize int) error {
if cs.TypeID == av.TAG_AUDIO {
cs.CSID = 4
} else if cs.TypeID == av.TAG_VIDEO ||
cs.TypeID == av.TAG_SCRIPTDATAAMF0 ||
cs.TypeID == av.TAG_SCRIPTDATAAMF3 {
cs.CSID = 6
}
totalLen := uint32(0)
numChunks := (cs.Length / uint32(chunkSize))
for i := uint32(0); i <= numChunks; i++ {
if totalLen == cs.Length {
break
}
if i == 0 {
cs.Format = uint32(0)
} else {
cs.Format = uint32(3)
}
if err := cs.writeHeader(w); err != nil {
return err
}
inc := uint32(chunkSize)
start := uint32(i) * uint32(chunkSize)
if uint32(len(cs.Data))-start <= inc {
inc = uint32(len(cs.Data)) - start
}
totalLen += inc
end := start + inc
buf := cs.Data[start:end]
if _, err := w.Write(buf); err != nil {
return err
}
}
return nil
}
func (cs *ChunkStream) readChunk(r *ReadWriter, chunkSize uint32) error {
var rmark bool = false //表示是不是读取一个message的第一个chunk
if cs.remain == 0 {
//如果一个ChunkStream没有剩余的内容没有读取,那么就应该人为上一个message已经结束
//设置rmark为true
rmark = true
}
var messageHeader [11]byte //message hader最长11个字节长度
var timeExtend bool = false
switch cs.tmpFromat {
case 0: //全类型,一般是一个chunk stream的开始,11个字节长度
if _, err := r.Read(messageHeader[0:]); err != nil {
return fmt.Errorf("read message header failed, %v", err)
}
cs.Format = cs.tmpFromat
cs.Timestamp = utils.U24BE(messageHeader[0:]) //timestamp 3个字节
cs.Length = utils.U24BE(messageHeader[3:]) //3字节
cs.TypeID = uint32(messageHeader[6]) //一个字节长度
cs.StreamID = utils.U32LE(messageHeader[7:]) //4个字节
if cs.Timestamp == 0xFFFFFF {
timeExtend = true
}
case 1: //与上一个属于同一个流, 7个字节长度
if _, err := r.Read(messageHeader[0:7]); err != nil {
return fmt.Errorf("read message header failed, %v", err)
}
cs.Format = cs.tmpFromat
cs.timeDelta = utils.U24BE(messageHeader[0:]) //timeDelta 3个字节
cs.Length = utils.U24BE(messageHeader[3:]) //3字节
cs.TypeID = uint32(messageHeader[6]) //一个字节长度
if cs.timeDelta == 0xFFFFFF {
timeExtend = true
}
case 2: //3个字节长度
if _, err := r.Read(messageHeader[0:3]); err != nil {
return fmt.Errorf("read message header failed, %v", err)
}
cs.Format = cs.tmpFromat
cs.timeDelta = utils.U24BE(messageHeader[0:]) //timeDelta 3个字节
if cs.timeDelta == 0xFFFFFF {
timeExtend = true
}
case 3: //0个字节长度
if cs.timeDelta == 0xFFFFFF {
timeExtend = true
}
default:
return fmt.Errorf("invalid fmt type:%d", cs.tmpFromat)
}
//如果有扩展时间戳,读取扩展时间戳
if timeExtend {
if _, err := r.Read(messageHeader[0:4]); err != nil {
return fmt.Errorf("read time extend failed, %v", err)
}
}
//如果是第一个chunk,设置pts的值,分配空间
if rmark {
cs.alloc()
switch cs.tmpFromat {
case 0:
if timeExtend {
cs.Pts = utils.U32BE(messageHeader[0:])
} else {
cs.Pts = cs.Timestamp
}
case 1, 2, 3:
if timeExtend {
cs.Pts += utils.U32BE(messageHeader[0:])
} else {
cs.Pts += cs.timeDelta
}
}
}
size := int(cs.remain)
if size > int(chunkSize) {
size = int(chunkSize)
}
dataBuf := cs.Data[cs.index : cs.index+uint32(size)]
//读取数据
if _, err := r.Read(dataBuf); err != nil {
return fmt.Errorf("read chunk data failed, %v", err)
}
cs.index += uint32(size)
cs.remain -= uint32(size)
if cs.remain == 0 {
cs.complete = true
}
return nil
}
func (cs *ChunkStream) readChunk1(r *ReadWriter, chunkSize uint32, pool *utils.Pool) error {
if cs.remain != 0 && cs.tmpFromat != 3 {
//如果remain != 0,说明还有消息没有读取完,所有fmt为类型3?
return fmt.Errorf("inlaid remin = %d", cs.remain)
}
switch cs.tmpFromat {
case 0: //全类型,一般是一个chunk stream的开始
cs.Format = cs.tmpFromat
cs.Timestamp, _ = r.ReadUintBE(3)
cs.Length, _ = r.ReadUintBE(3)
cs.TypeID, _ = r.ReadUintBE(1)
cs.StreamID, _ = r.ReadUintLE(4)
if cs.Timestamp == 0xffffff {
cs.Timestamp, _ = r.ReadUintBE(4)
cs.exted = true
} else {
cs.exted = false
}
cs.alloc()
case 1: //与上一个属于同一个流
cs.Format = cs.tmpFromat
timeStamp, _ := r.ReadUintBE(3)
cs.Length, _ = r.ReadUintBE(3)
cs.TypeID, _ = r.ReadUintBE(1)
if timeStamp == 0xffffff {
timeStamp, _ = r.ReadUintBE(4)
cs.exted = true
} else {
cs.exted = false
}
cs.timeDelta = timeStamp
cs.Timestamp += timeStamp
cs.alloc()
case 2: //时间戳不一样
cs.Format = cs.tmpFromat
timeStamp, _ := r.ReadUintBE(3)
if timeStamp == 0xffffff {
timeStamp, _ = r.ReadUintBE(4)
cs.exted = true
} else {
cs.exted = false
}
cs.timeDelta = timeStamp
cs.Timestamp += timeStamp
cs.alloc()
case 3: //都一样
if cs.remain == 0 {
//如果cs.remain == 0,表示是该message的第一个包,要处理时间戳,
//所有的同一个message的chunk,时间戳应该是一样的,只处理第一个就可以了
switch cs.Format {
case 0:
if cs.exted {
timestamp, _ := r.ReadUintBE(4)
cs.Timestamp = timestamp
}
case 1, 2:
var timedet uint32
if cs.exted {
timedet, _ = r.ReadUintBE(4)
} else {
timedet = cs.timeDelta
}
cs.Timestamp += timedet
}
cs.alloc()
} else {
if cs.exted {
b, err := r.Peek(4)
if err != nil {
return err
}
tmpts := binary.BigEndian.Uint32(b)
if tmpts == cs.Timestamp {
r.Discard(4)
}
}
}
default:
return fmt.Errorf("invalid format=%d", cs.Format)
}
size := int(cs.remain)
if size > int(chunkSize) {
size = int(chunkSize)
}
buf := cs.Data[cs.index : cs.index+uint32(size)]
//读取数据
if _, err := r.Read(buf); err != nil {
return err
}
cs.index += uint32(size)
cs.remain -= uint32(size)
if cs.remain == 0 {
cs.complete = true
}
return r.readError
}
+694
View File
@@ -0,0 +1,694 @@
package core
import (
"bytes"
"errors"
"fmt"
"io"
"math/rand"
"net"
neturl "net/url"
"strings"
"github.com/H0RlZ0N/gortmppush/av"
"github.com/H0RlZ0N/gortmppush/logger"
"github.com/H0RlZ0N/gortmppush/protocol/amf"
)
var (
respResult = "_result"
respError = "_error"
onStatus = "onStatus"
publishStart = "NetStream.Publish.Start"
playStart = "NetStream.Play.Start"
playReset = "NetStream.Play.Reset"
connectSuccess = "NetConnection.Connect.Success"
onBWDone = "onBWDone"
)
var (
errFail = errors.New("respone err")
)
// ConnClient ...
type ConnClient struct {
done bool
transID int
url string
tcurl string
app string
title string
query string
curcmdName string
streamid uint32
conn *RtmpConn
encoder *amf.Encoder
decoder *amf.Decoder
bytesw *bytes.Buffer
logger logger.Logger
}
// NewConnClient ...
func NewConnClient(log logger.Logger) *ConnClient {
return &ConnClient{
transID: 1, //todo 写死?
bytesw: bytes.NewBuffer(nil),
encoder: &amf.Encoder{},
decoder: &amf.Decoder{},
logger: log,
}
}
// DecodeBatch ...
func (cc *ConnClient) DecodeBatch(r io.Reader, ver amf.Version) (ret []interface{}, err error) {
return cc.decoder.DecodeBatch(r, ver)
}
func (cc *ConnClient) Decode(r io.Reader, ver amf.Version) (interface{}, error) {
return cc.decoder.Decode(r, ver)
}
// todo 需要完善,功能不完整
func (cc *ConnClient) waitForResponse(commandName string) error {
for {
//读取一个完整的一个message
cs, err := cc.conn.Read()
if err != nil {
return fmt.Errorf("read chunk stream failed, %v", err)
}
switch cs.TypeID {
case 18, 15: //数据消息,传递一些元数据 amf0-18, amf3-15
//理论上不应该出现数据消息
return errors.New("metadata message should not received")
case 19, 16: //共享对象消息, afm0-19, afm3-16
//忽略共享消息??
cc.logger.Warn("shared message received.")
continue
case 8, 9: //音视频消息, 8-音频数据 9-视频数据
//不应该出现音视频消息
return errors.New("video and audio message should not received")
case 22: //组合消息
//忽略组合消息??
cc.logger.Warn("aggregage message received.")
case 4: //用户控制消息
//发送connect后,会接收到用户控制消息,比如Stream Begin
//todo 如何解析用户消息
cc.logger.Warn("user control message received.")
continue //忽略该消息
case 20, 17: //控制消息 amf0-20, amf3-17
var vs []interface{}
r := bytes.NewReader(cs.Data)
if cs.TypeID == 20 {
vs, err = cc.decoder.DecodeBatch(r, amf.AMF0)
} else if cs.TypeID == 17 {
vs, err = cc.decoder.DecodeBatch(r, amf.AMF3)
}
if err != nil && err != io.EOF {
return fmt.Errorf("decode chunk stream failed, %v", err)
}
switch commandName {
case cmdConnect:
var bResult, bTransID, bCode bool
for k, v := range vs {
if result, ok := v.(string); ok && result == respResult {
bResult = true
} else if transID, ok := v.(float64); ok {
if k == 1 && int(transID) == cc.transID {
bTransID = true
}
} else if objmap, ok := v.(amf.Object); ok {
if obj, ok := objmap["code"]; ok {
if code, ok := obj.(string); ok && code == connectSuccess {
bCode = true
}
}
}
}
if !bResult || !bTransID || !bCode {
return fmt.Errorf("result:%v transID:%v code:%v", bResult, bTransID, bCode)
}
case cmdCreateStream:
var bResult, bTransID bool
for k, v := range vs {
if result, ok := v.(string); ok && result == respResult {
bResult = true
} else if id, ok := v.(float64); ok {
if k == 1 && int(id) == cc.transID {
bTransID = true
}
if k == 3 {
cc.streamid = uint32(id)
}
}
}
if !bResult || !bTransID {
return fmt.Errorf("result:%v transID:%v", bResult, bTransID)
}
case cmdPlay:
var bResult, bStart, bReset bool
for _, v := range vs {
if result, ok := v.(string); ok && result == onStatus {
bResult = true
} else if objmap, ok := v.(amf.Object); ok {
if obj, ok := objmap["code"]; ok {
if code, ok := obj.(string); ok {
if code == playReset {
bReset = true
} else if code == playStart {
bStart = true
}
}
}
}
}
if !bResult || (!bStart && !bReset) {
return fmt.Errorf("result:%v start:%v reset:%v", bResult, bStart, bReset)
}
case cmdPublish:
var bResult, bStart bool
for _, v := range vs {
if result, ok := v.(string); ok && result == onStatus {
bResult = true
} else if objmap, ok := v.(amf.Object); ok {
if obj, ok := objmap["code"]; ok {
if code, ok := obj.(string); ok && code == publishStart {
bStart = true
}
}
}
}
if !bResult || !bStart {
return fmt.Errorf("result:%v code:%v", bResult, bStart)
}
default:
return fmt.Errorf("unknow command:%s", commandName)
}
}
return nil
}
}
func (cc *ConnClient) readRespMsg() error {
for {
//读取一个chunk
rc, err := cc.conn.Read()
if err != nil {
if err != io.EOF {
return err
}
}
switch rc.TypeID {
case 20, 17: //如果是控制消息
r := bytes.NewReader(rc.Data)
vs, _ := cc.decoder.DecodeBatch(r, amf.AMF0)
for k, v := range vs {
switch v.(type) {
case string:
switch cc.curcmdName {
case cmdConnect, cmdCreateStream:
if v.(string) != respResult {
return errors.New(v.(string))
}
case cmdPublish:
if v.(string) != onStatus {
return errFail
}
}
case float64:
switch cc.curcmdName {
case cmdConnect, cmdCreateStream:
id := int(v.(float64))
if k == 1 {
if id != cc.transID {
return errFail
}
} else if k == 3 {
cc.streamid = uint32(id)
}
case cmdPublish:
if int(v.(float64)) != 0 {
return errFail
}
}
case amf.Object:
objmap := v.(amf.Object)
switch cc.curcmdName {
case cmdConnect:
code, ok := objmap["code"]
if ok && code.(string) != connectSuccess {
return errFail
}
case cmdPublish:
code, ok := objmap["code"]
if ok && code.(string) != publishStart {
return errFail
}
}
}
}
return nil
}
}
}
func (cc *ConnClient) writeMsg(args ...interface{}) error {
cc.bytesw.Reset()
for _, v := range args {
if _, err := cc.encoder.Encode(cc.bytesw, v, amf.AMF0); err != nil {
return err
}
}
msg := cc.bytesw.Bytes()
c := ChunkStream{
Format: 0,
CSID: 3,
Timestamp: 0,
TypeID: 20,
StreamID: cc.streamid,
Length: uint32(len(msg)),
Data: msg,
}
cc.conn.Write(&c)
return cc.conn.Flush()
}
func (cc *ConnClient) writeConnectMsg() error {
event := make(amf.Object)
event["app"] = cc.app
event["type"] = "nonprivate"
event["flashVer"] = "FMS.3.1"
event["tcUrl"] = cc.tcurl
cc.curcmdName = cmdConnect
cc.logger.Tracef("writeConnectMsg: connClient.transID=%d, event=%v", cc.transID, event)
if err := cc.writeMsg(cmdConnect, cc.transID, event); err != nil {
return err
}
return nil
}
func (cc *ConnClient) writeCreateStreamMsg() error {
cc.transID++
cc.curcmdName = cmdCreateStream
cc.logger.Tracef("writeCreateStreamMsg: connClient.transID=%d", cc.transID)
if err := cc.writeMsg(cmdCreateStream, cc.transID, nil); err != nil {
return err
}
return nil
}
func (cc *ConnClient) writePublishMsg() error {
cc.transID++
cc.curcmdName = cmdPublish
if err := cc.writeMsg(cmdPublish, cc.transID, nil, cc.title, publishLive); err != nil {
return err
}
return nil
}
func (cc *ConnClient) writePlayMsg() error {
cc.transID++
cc.curcmdName = cmdPlay
cc.logger.Tracef("writePlayMsg: connClient.transID=%d, cmdPlay=%v, connClient.title=%v",
cc.transID, cmdPlay, cc.title)
if err := cc.writeMsg(cmdPlay, 0, nil, cc.title); err != nil {
return err
}
return nil
}
func (cc *ConnClient) parseURL(url string) (local, remote string, err error) {
var parsedURL *neturl.URL
if parsedURL, err = neturl.Parse(url); err != nil {
err = fmt.Errorf("parse url failed, %v", err)
return
}
cc.url = url
path := strings.TrimLeft(parsedURL.Path, "/")
ps := strings.SplitN(path, "/", 2)
if len(ps) != 2 {
err = fmt.Errorf("path err, %s", path)
return
}
cc.app = ps[0]
cc.title = ps[1]
cc.query = parsedURL.RawQuery
cc.tcurl = "rtmp://" + parsedURL.Host + "/" + cc.app
port := ":1935"
host := parsedURL.Host
local = ":0"
if strings.Index(host, ":") != -1 {
host, port, err = net.SplitHostPort(host)
if err != nil {
return
}
port = ":" + port
}
var ips []net.IP
if ips, err = net.LookupIP(host); err != nil {
err = fmt.Errorf("net.LookupIP failed, %v", err)
return
}
remote = ips[rand.Intn(len(ips))].String()
if strings.Index(remote, ":") == -1 {
remote += port
}
return
}
func (cc *ConnClient) connectServer(url string) error {
localIP, remoteIP, err := cc.parseURL(url)
if err != nil {
return fmt.Errorf("parse url:%s faile, %v", url, err)
}
var localAddr, remoteAddr *net.TCPAddr
if localAddr, err = net.ResolveTCPAddr("tcp", localIP); err != nil {
return fmt.Errorf("net.ResolveTCPAddr localIP failed, %v", err)
} else if remoteAddr, err = net.ResolveTCPAddr("tcp", remoteIP); err != nil {
return fmt.Errorf("net.ResolveTCPAddr remoteIP failed, %v", err)
}
var conn *net.TCPConn
if conn, err = net.DialTCP("tcp", localAddr, remoteAddr); err != nil {
return fmt.Errorf("net.DialTCP failed, %v", err)
}
rtmpConn := NewRtmpConn(conn, 4*1024)
defer func() {
if err != nil {
rtmpConn.Close()
}
}()
cc.logger.Debug("HandsakeClient...")
if err = rtmpConn.HandshakeClient(); err != nil {
return fmt.Errorf("HandshakeClient failed, %v", err)
}
cc.conn = rtmpConn
return nil
}
func (cc *ConnClient) checkResponse(commandName string, values []interface{}) error {
var resultOK bool = false
for k, v := range values {
switch v.(type) {
case string:
if commandName == cmdConnect || commandName == cmdCreateStream {
if v.(string) != respResult {
return errors.New(v.(string))
}
resultOK = true
} else if commandName == cmdPublish {
if v.(string) != onStatus {
return errFail
}
resultOK = true
}
case float64:
if commandName == cmdConnect || commandName == cmdCreateStream {
id := int(v.(float64))
if k == 1 {
if id != cc.transID {
return errFail
}
} else if k == 3 {
cc.streamid = uint32(id)
}
} else if commandName == cmdPublish {
if int(v.(float64)) != 0 {
return errFail
}
}
case amf.Object:
objmap := v.(amf.Object)
if commandName == cmdConnect {
if code, ok := objmap["code"]; ok && code.(string) != connectSuccess {
return errFail
}
} else if commandName == cmdPublish {
if code, ok := objmap["code"]; ok && code.(string) != publishStart {
return errFail
}
}
}
}
if !resultOK {
return fmt.Errorf("check result failed")
}
return nil
}
// netConnection 建立网络连接
// 先发送connect消息,然后等待connect response
func (cc *ConnClient) netConnection() (err error) {
if err := cc.writeConnectMsg(); err != nil {
return fmt.Errorf("write connect message failed, %v", err)
}
var cs *ChunkStream
for {
if cs, err = cc.conn.Read(); err != nil {
return fmt.Errorf("read chunk stream failed, %v", err)
}
switch cs.TypeID {
case 20, 17: //指令消息, amf0-20, amf3-17
//处理connect消息响应
var vs []interface{}
r := bytes.NewReader(cs.Data)
if cs.TypeID == 20 {
vs, err = cc.decoder.DecodeBatch(r, amf.AMF0)
} else if cs.TypeID == 17 {
vs, err = cc.decoder.DecodeBatch(r, amf.AMF3)
}
if err != nil && err != io.EOF {
return fmt.Errorf("decode chunk stream failed, %v", err)
}
cc.logger.Tracef("connect response:%v", vs)
if err = cc.checkResponse(cmdConnect, vs); err != nil {
return fmt.Errorf("check connect response failed, %v", err)
}
cc.logger.Trace("check connect response success")
//如果校验通过,就返回
return
case 18, 15: //数据消息,传递一些元数据 amf0-18, amf3-15
//理论上不应该出现数据消息
return errors.New("metadata message should not received")
case 19, 16: //共享对象消息, afm0-19, afm3-16
//忽略共享消息??
cc.logger.Warn("shared message received.")
case 8, 9: //音视频消息, 8-音频数据 9-视频数据
//不应该出现音视频消息
return errors.New("video and audio message should not received")
case 22: //组合消息
//忽略组合消息??
cc.logger.Warn("aggregage message received.")
case 4: //用户控制消息
//发送connect后,会接收到用户控制消息,比如Stream Begin
//todo 如何解析用户消息
cc.logger.Warn("user control message received.")
}
}
}
// streamConnection 建立流连接
func (cc *ConnClient) streamConnection() (err error) {
if err = cc.writeCreateStreamMsg(); err != nil {
return fmt.Errorf("write create stream failed, %v", err)
}
var cs *ChunkStream
for {
if cs, err = cc.conn.Read(); err != nil {
return fmt.Errorf("read chunk stream failed, %v", err)
}
switch cs.TypeID {
case 20, 17: //指令消息, amf0-20, amf3-17
//处理connect消息响应
var vs []interface{}
r := bytes.NewReader(cs.Data)
if cs.TypeID == 20 {
vs, err = cc.decoder.DecodeBatch(r, amf.AMF0)
} else if cs.TypeID == 17 {
vs, err = cc.decoder.DecodeBatch(r, amf.AMF3)
}
if err != nil && err != io.EOF {
return fmt.Errorf("decode chunk stream failed, %v", err)
}
cc.logger.Tracef("create stream response:%v", vs)
if err = cc.checkResponse(cmdCreateStream, vs); err != nil {
return fmt.Errorf("check create stream response failed, %v", err)
}
cc.logger.Trace("check create stream response success")
//如果校验通过,就返回
return
case 18, 15: //数据消息,传递一些元数据 amf0-18, amf3-15
//理论上不应该出现数据消息
return errors.New("metadata message should not received")
case 19, 16: //共享对象消息, afm0-19, afm3-16
//忽略共享消息??
cc.logger.Warn("shared message received.")
case 8, 9: //音视频消息, 8-音频数据 9-视频数据
//不应该出现音视频消息
return errors.New("video and audio message should not received")
case 22: //组合消息
//忽略组合消息??
cc.logger.Warn("aggregage message received.")
case 4: //用户控制消息
//发送connect后,会接收到用户控制消息,比如Stream Begin
//todo 如何解析用户消息
cc.logger.Warn("user control message received.")
}
}
}
func (cc *ConnClient) setupPlayOrPublish(method string) (err error) {
if method == av.PLAY {
err = cc.writePlayMsg()
} else if method == av.PUBLISH {
err = cc.writePublishMsg()
} else {
return fmt.Errorf("unsupport method:%s", method)
}
var cs *ChunkStream
for {
if cs, err = cc.conn.Read(); err != nil {
return fmt.Errorf("read chunk stream failed, %v", err)
}
switch cs.TypeID {
case 20, 17: //指令消息, amf0-20, amf3-17
//处理connect消息响应
var vs []interface{}
r := bytes.NewReader(cs.Data)
if cs.TypeID == 20 {
vs, err = cc.decoder.DecodeBatch(r, amf.AMF0)
} else if cs.TypeID == 17 {
vs, err = cc.decoder.DecodeBatch(r, amf.AMF3)
}
if err != nil && err != io.EOF {
return fmt.Errorf("decode chunk stream failed, %v", err)
}
cc.logger.Tracef("play response:%v", vs)
if err = cc.checkResponse(cmdPlay, vs); err != nil {
return fmt.Errorf("check play response failed, %v", err)
}
cc.logger.Trace("check play response success")
//如果校验通过,就返回
return
case 18, 15: //数据消息,传递一些元数据 amf0-18, amf3-15
//理论上不应该出现数据消息
return errors.New("metadata message should not received")
case 19, 16: //共享对象消息, afm0-19, afm3-16
//忽略共享消息??
cc.logger.Warn("shared message received.")
case 8, 9: //音视频消息, 8-音频数据 9-视频数据
//不应该出现音视频消息
return errors.New("video and audio message should not received")
case 22: //组合消息
//忽略组合消息??
cc.logger.Warn("aggregage message received.")
case 4: //用户控制消息
//发送connect后,会接收到用户控制消息,比如Stream Begin
//todo 如何解析用户消息
cc.logger.Warn("user control message received.")
}
}
}
// Start ...
func (cc *ConnClient) Start(url string, method string) (err error) {
if err = cc.connectServer(url); err != nil {
return fmt.Errorf("connect to server failed, %v", err)
}
curCommand := cmdConnect
for {
switch curCommand {
case cmdConnect:
if err = cc.writeConnectMsg(); err != nil {
return fmt.Errorf("write connect msg failed, %v", err)
}
case cmdCreateStream:
if err = cc.writeCreateStreamMsg(); err != nil {
return fmt.Errorf("write create stream failed, %v", err)
}
case cmdPlay:
if err = cc.writePlayMsg(); err != nil {
return fmt.Errorf("write play msg failed, %v", err)
}
case cmdPublish:
if err = cc.writePublishMsg(); err != nil {
return fmt.Errorf("write publish msg failed, %v", err)
}
}
cc.logger.Tracef("Send command:%s success", curCommand)
if err = cc.waitForResponse(curCommand); err != nil {
return fmt.Errorf("wait for %s response failed, %v", curCommand, err)
}
cc.logger.Tracef("wait for %s response success", curCommand)
switch curCommand {
case cmdConnect:
curCommand = cmdCreateStream
case cmdCreateStream:
if method == av.PUBLISH {
curCommand = cmdPublish
} else if method == av.PLAY {
curCommand = cmdPlay
} else {
return fmt.Errorf("unsupport method:%s", method)
}
case cmdPlay, cmdPublish:
return nil
}
}
}
func (cc *ConnClient) Write(c *ChunkStream) error {
if c.TypeID == av.TAG_SCRIPTDATAAMF0 || c.TypeID == av.TAG_SCRIPTDATAAMF3 {
var err error
if c.Data, err = amf.MetaDataReform(c.Data, amf.ADD); err != nil {
return err
}
c.Length = uint32(len(c.Data))
}
return cc.conn.Write(c)
}
// Flush ...
func (cc *ConnClient) Flush() error {
return cc.conn.Flush()
}
func (cc *ConnClient) Read() (*ChunkStream, error) {
return cc.conn.Read()
}
// GetStreamInfo ...
func (cc *ConnClient) GetStreamInfo() (app string, name string, url string) {
app = cc.app
name = cc.title
url = cc.url
return
}
// GetStreamID ...
func (cc *ConnClient) GetStreamID() uint32 {
return cc.streamid
}
// Close ...
func (cc *ConnClient) Close() {
cc.conn.Close()
}
+442
View File
@@ -0,0 +1,442 @@
package core
import (
"bytes"
"errors"
"fmt"
"io"
"github.com/H0RlZ0N/gortmppush/av"
"github.com/H0RlZ0N/gortmppush/logger"
"github.com/H0RlZ0N/gortmppush/protocol/amf"
)
var (
publishLive = "live"
publishRecord = "record"
publishAppend = "append"
)
var (
ErrReq = errors.New("req error")
)
var (
cmdConnect = "connect"
cmdFcpublish = "FCPublish"
cmdReleaseStream = "releaseStream"
cmdCreateStream = "createStream"
cmdPublish = "publish"
cmdFCUnpublish = "FCUnpublish"
cmdDeleteStream = "deleteStream"
cmdPlay = "play"
)
// ConnectInfo ...
type ConnectInfo struct {
App string `amf:"app" json:"app"`
Flashver string `amf:"flashVer" json:"flashVer"`
SwfURL string `amf:"swfUrl" json:"swfUrl"`
TcURL string `amf:"tcUrl" json:"tcUrl"`
Fpad bool `amf:"fpad" json:"fpad"`
AudioCodecs int `amf:"audioCodecs" json:"audioCodecs"`
VideoCodecs int `amf:"videoCodecs" json:"videoCodecs"`
VideoFunction int `amf:"videoFunction" json:"videoFunction"`
PageURL string `amf:"pageUrl" json:"pageUrl"`
ObjectEncoding int `amf:"objectEncoding" json:"objectEncoding"`
}
// ConnectResp ...
type ConnectResp struct {
FMSVer string `amf:"fmsVer"`
Capabilities int `amf:"capabilities"`
}
// ConnectEvent ...
type ConnectEvent struct {
Level string `amf:"level"`
Code string `amf:"code"`
Description string `amf:"description"`
ObjectEncoding int `amf:"objectEncoding"`
}
// PublishInfo ...
type PublishInfo struct {
Name string
Type string
}
// ForwardConnect 与客户端对应的rtmp连接
type ForwardConnect struct {
done bool
streamID int
isPublisher bool
conn *RtmpConn
transactionID int
ConnInfo ConnectInfo
PublishInfo PublishInfo
decoder *amf.Decoder
encoder *amf.Encoder
bytesw *bytes.Buffer
logger logger.Logger
}
// NewForwardConnect 创建一个与客户端对应的rtmp连接
func NewForwardConnect(conn *RtmpConn, log logger.Logger) *ForwardConnect {
return &ForwardConnect{
conn: conn,
streamID: 1, //todo
bytesw: bytes.NewBuffer(nil),
decoder: &amf.Decoder{},
encoder: &amf.Encoder{},
logger: log,
}
}
func (fc *ForwardConnect) writeMsg(csid, streamID uint32, args ...interface{}) error {
fc.bytesw.Reset()
for _, v := range args {
if _, err := fc.encoder.Encode(fc.bytesw, v, amf.AMF0); err != nil {
return err
}
}
msg := fc.bytesw.Bytes()
c := ChunkStream{
Format: 0,
CSID: csid,
Timestamp: 0,
TypeID: 20,
StreamID: streamID,
Length: uint32(len(msg)),
Data: msg,
}
fc.conn.Write(&c)
return fc.conn.Flush()
}
func (fc *ForwardConnect) connect(vs []interface{}) error {
for _, v := range vs {
switch v.(type) {
case string:
case float64:
id := int(v.(float64))
if id != 1 {
return ErrReq
}
fc.transactionID = id
case amf.Object:
obimap := v.(amf.Object)
if app, ok := obimap["app"]; ok {
fc.ConnInfo.App = app.(string)
}
if flashVer, ok := obimap["flashVer"]; ok {
fc.ConnInfo.Flashver = flashVer.(string)
}
if tcurl, ok := obimap["tcUrl"]; ok {
fc.ConnInfo.TcURL = tcurl.(string)
}
if encoding, ok := obimap["objectEncoding"]; ok {
fc.ConnInfo.ObjectEncoding = int(encoding.(float64))
}
}
}
return nil
}
func (fc *ForwardConnect) releaseStream(vs []interface{}) error {
return nil
}
func (fc *ForwardConnect) fcPublish(vs []interface{}) error {
return nil
}
// todo 参数是否要定死
func (fc *ForwardConnect) connectResp(cur *ChunkStream) error {
c := fc.conn.NewWindowAckSize(2500000)
fc.conn.Write(&c)
c = fc.conn.NewSetPeerBandwidth(2500000)
fc.conn.Write(&c)
c = fc.conn.NewSetChunkSize(uint32(1024))
fc.conn.Write(&c)
resp := make(amf.Object)
resp["fmsVer"] = "FMS/3,0,1,123"
resp["capabilities"] = 31
event := make(amf.Object)
event["level"] = "status"
event["code"] = "NetConnection.Connect.Success"
event["description"] = "Connection succeeded."
event["objectEncoding"] = fc.ConnInfo.ObjectEncoding
return fc.writeMsg(cur.CSID, cur.StreamID, "_result", fc.transactionID, resp, event)
}
func (fc *ForwardConnect) createStream(vs []interface{}) error {
for _, v := range vs {
switch v.(type) {
case string:
case float64:
fc.transactionID = int(v.(float64))
case amf.Object:
}
}
return nil
}
func (fc *ForwardConnect) createStreamResp(cur *ChunkStream) error {
return fc.writeMsg(cur.CSID, cur.StreamID, "_result", fc.transactionID, nil, fc.streamID)
}
func (fc *ForwardConnect) publishOrPlay(vs []interface{}) error {
for k, v := range vs {
switch v.(type) {
case string:
if k == 2 {
fc.PublishInfo.Name = v.(string)
} else if k == 3 {
fc.PublishInfo.Type = v.(string)
}
case float64:
id := int(v.(float64))
fc.transactionID = id
case amf.Object:
}
}
return nil
}
func (fc *ForwardConnect) publishResp(cur *ChunkStream) error {
event := make(amf.Object)
event["level"] = "status"
event["code"] = "NetStream.Publish.Start"
event["description"] = "Start publising."
return fc.writeMsg(cur.CSID, cur.StreamID, "onStatus", 0, nil, event)
}
func (fc *ForwardConnect) playResp(cur *ChunkStream) error {
fc.conn.SetRecorded()
fc.conn.SetBegin()
event := make(amf.Object)
event["level"] = "status"
event["code"] = "NetStream.Play.Reset"
event["description"] = "Playing and resetting stream."
if err := fc.writeMsg(cur.CSID, cur.StreamID, "onStatus", 0, nil, event); err != nil {
return err
}
event["level"] = "status"
event["code"] = "NetStream.Play.Start"
event["description"] = "Started playing stream."
if err := fc.writeMsg(cur.CSID, cur.StreamID, "onStatus", 0, nil, event); err != nil {
return err
}
event["level"] = "status"
event["code"] = "NetStream.Data.Start"
event["description"] = "Started playing stream."
if err := fc.writeMsg(cur.CSID, cur.StreamID, "onStatus", 0, nil, event); err != nil {
return err
}
event["level"] = "status"
event["code"] = "NetStream.Play.PublishNotify"
event["description"] = "Started playing notify."
if err := fc.writeMsg(cur.CSID, cur.StreamID, "onStatus", 0, nil, event); err != nil {
return err
}
return fc.conn.Flush()
}
func (fc *ForwardConnect) handleCmdMsg(c *ChunkStream) error {
amfType := amf.AMF0
if c.TypeID == 17 {
c.Data = c.Data[1:]
}
r := bytes.NewReader(c.Data)
vs, err := fc.decoder.DecodeBatch(r, amf.Version(amfType))
if err != nil && err != io.EOF {
return err
}
// glog.Infof("rtmp req: %#v", vs)
switch vs[0].(type) {
case string:
switch vs[0].(string) {
case cmdConnect:
if err = fc.connect(vs[1:]); err != nil {
return err
}
if err = fc.connectResp(c); err != nil {
return err
}
case cmdCreateStream:
if err = fc.createStream(vs[1:]); err != nil {
return err
}
if err = fc.createStreamResp(c); err != nil {
return err
}
case cmdPublish:
if err = fc.publishOrPlay(vs[1:]); err != nil {
return err
}
if err = fc.publishResp(c); err != nil {
return err
}
fc.done = true
fc.isPublisher = true
case cmdPlay:
if err = fc.publishOrPlay(vs[1:]); err != nil {
return err
}
if err = fc.playResp(c); err != nil {
return err
}
fc.done = true
fc.isPublisher = false
fmt.Printf("handle play req done\n")
case cmdFcpublish:
fc.fcPublish(vs)
case cmdReleaseStream:
fc.releaseStream(vs)
case cmdFCUnpublish:
case cmdDeleteStream:
default:
fc.logger.Warnf("no support command:%s", vs[0].(string))
}
}
return nil
}
// SetUpPlayOrPublish 等待客户端完成推流或拉流请求
// todo 需要增加超时,防止连接一直在却不发送任何消息
func (fc *ForwardConnect) SetUpPlayOrPublish() error {
amfType := amf.AMF0
for {
chunk, err := fc.conn.Read()
if err != nil {
return fmt.Errorf("Read chunk stream failed, %v", err)
}
//todo 需要注释一下, 20,17代表什么消息类型
if chunk.TypeID != 17 && chunk.TypeID != 20 {
continue
} else if chunk.TypeID == 17 {
chunk.Data = chunk.Data[1:]
}
r := bytes.NewReader(chunk.Data)
vs, err := fc.decoder.DecodeBatch(r, amf.Version(amfType))
if err != nil && err != io.EOF {
return fmt.Errorf("Amf DecodeBatch failed, %v", err)
}
if cmd, ok := vs[0].(string); ok {
switch cmd {
case cmdConnect:
if err = fc.connect(vs[1:]); err != nil {
return fmt.Errorf("handle connect cmd failed, %v", err)
}
if err = fc.connectResp(chunk); err != nil {
return fmt.Errorf("connect response failed, %v", err)
}
case cmdCreateStream:
if err = fc.createStream(vs[1:]); err != nil {
return fmt.Errorf("handle create stream cmd failed, %v", err)
}
if err = fc.createStreamResp(chunk); err != nil {
return fmt.Errorf("create stream response failed, %v", err)
}
case cmdPublish:
if err = fc.publishOrPlay(vs[1:]); err != nil {
return fmt.Errorf("handle publish command failed, %v", err)
}
if err = fc.publishResp(chunk); err != nil {
return fmt.Errorf("publish response failed, %v", err)
}
fc.isPublisher = true
return nil
case cmdPlay:
if err = fc.publishOrPlay(vs[1:]); err != nil {
return fmt.Errorf("handle play command failed, %v", err)
}
if err = fc.playResp(chunk); err != nil {
return fmt.Errorf("play response failed, %v", err)
}
fc.isPublisher = false
return nil
case cmdFcpublish:
fc.fcPublish(vs)
case cmdReleaseStream:
fc.releaseStream(vs)
case cmdFCUnpublish:
case cmdDeleteStream:
default:
fc.logger.Warnf("no support command:%s", cmd)
}
}
}
}
// ReadMsg is a method
func (fc *ForwardConnect) ReadMsg() error {
for {
c, err := fc.conn.Read()
if err != nil {
return err
}
switch c.TypeID {
case 20, 17:
if err := fc.handleCmdMsg(c); err != nil {
return err
}
}
if fc.done {
break
}
}
return nil
}
// IsPublisher ...
func (fc *ForwardConnect) IsPublisher() bool {
return fc.isPublisher
}
// Write ...
func (fc *ForwardConnect) Write(c ChunkStream) error {
if c.TypeID == av.TAG_SCRIPTDATAAMF0 ||
c.TypeID == av.TAG_SCRIPTDATAAMF3 {
var err error
if c.Data, err = amf.MetaDataReform(c.Data, amf.DEL); err != nil {
return err
}
c.Length = uint32(len(c.Data))
}
return fc.conn.Write(&c)
}
// Flush ...
func (fc *ForwardConnect) Flush() error {
return fc.conn.Flush()
}
// Read ...
func (fc *ForwardConnect) Read() (*ChunkStream, error) {
return fc.conn.Read()
}
// GetStreamInfo ...
func (fc *ForwardConnect) GetStreamInfo() (app string, name string, url string) {
app = fc.ConnInfo.App
name = fc.PublishInfo.Name
url = fc.ConnInfo.TcURL + "/" + fc.PublishInfo.Name
return
}
// Close ...
func (fc *ForwardConnect) Close() {
fc.conn.Close()
}
+209
View File
@@ -0,0 +1,209 @@
package core
import (
"bytes"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"fmt"
"io"
"time"
"github.com/H0RlZ0N/gortmppush/utils"
)
var (
timeout = 5 * time.Second
)
var (
hsClientFullKey = []byte{
'G', 'e', 'n', 'u', 'i', 'n', 'e', ' ', 'A', 'd', 'o', 'b', 'e', ' ',
'F', 'l', 'a', 's', 'h', ' ', 'P', 'l', 'a', 'y', 'e', 'r', ' ',
'0', '0', '1',
0xF0, 0xEE, 0xC2, 0x4A, 0x80, 0x68, 0xBE, 0xE8, 0x2E, 0x00, 0xD0, 0xD1,
0x02, 0x9E, 0x7E, 0x57, 0x6E, 0xEC, 0x5D, 0x2D, 0x29, 0x80, 0x6F, 0xAB,
0x93, 0xB8, 0xE6, 0x36, 0xCF, 0xEB, 0x31, 0xAE,
}
hsServerFullKey = []byte{
'G', 'e', 'n', 'u', 'i', 'n', 'e', ' ', 'A', 'd', 'o', 'b', 'e', ' ',
'F', 'l', 'a', 's', 'h', ' ', 'M', 'e', 'd', 'i', 'a', ' ',
'S', 'e', 'r', 'v', 'e', 'r', ' ',
'0', '0', '1',
0xF0, 0xEE, 0xC2, 0x4A, 0x80, 0x68, 0xBE, 0xE8, 0x2E, 0x00, 0xD0, 0xD1,
0x02, 0x9E, 0x7E, 0x57, 0x6E, 0xEC, 0x5D, 0x2D, 0x29, 0x80, 0x6F, 0xAB,
0x93, 0xB8, 0xE6, 0x36, 0xCF, 0xEB, 0x31, 0xAE,
}
hsClientPartialKey = hsClientFullKey[:30]
hsServerPartialKey = hsServerFullKey[:36]
)
func hsMakeDigest(key []byte, src []byte, gap int) (dst []byte) {
h := hmac.New(sha256.New, key)
if gap <= 0 {
h.Write(src)
} else {
h.Write(src[:gap])
h.Write(src[gap+32:])
}
return h.Sum(nil)
}
func hsCalcDigestPos(p []byte, base int) (pos int) {
for i := 0; i < 4; i++ {
pos += int(p[base+i])
}
pos = (pos % 728) + base + 4
return
}
func hsFindDigest(p []byte, key []byte, base int) int {
gap := hsCalcDigestPos(p, base)
digest := hsMakeDigest(key, p, gap)
if bytes.Compare(p[gap:gap+32], digest) != 0 {
return -1
}
return gap
}
func hsParse1(p []byte, peerkey []byte, key []byte) (ok bool, digest []byte) {
var pos int
if pos = hsFindDigest(p, peerkey, 772); pos == -1 {
if pos = hsFindDigest(p, peerkey, 8); pos == -1 {
return
}
}
ok = true
digest = hsMakeDigest(key, p[pos:pos+32], -1)
return
}
func hsCreate01(p []byte, time uint32, ver uint32, key []byte) {
p[0] = 3
p1 := p[1:]
rand.Read(p1[8:])
utils.PutU32BE(p1[0:4], time)
utils.PutU32BE(p1[4:8], ver)
gap := hsCalcDigestPos(p1, 8)
digest := hsMakeDigest(key, p1, gap)
copy(p1[gap:], digest)
}
func hsCreate2(p []byte, key []byte) {
rand.Read(p)
gap := len(p) - 32
digest := hsMakeDigest(key, p, gap)
copy(p[gap:], digest)
}
// HandshakeClient todo comment
func (conn *RtmpConn) HandshakeClient() (err error) {
var random [(1 + 1536*2) * 2]byte
C0C1C2 := random[:1536*2+1]
C0 := C0C1C2[:1]
C0C1 := C0C1C2[:1536+1]
C2 := C0C1C2[1536+1:]
S0S1S2 := random[1536*2+1:]
C0[0] = 3
// > C0C1
conn.Conn.SetDeadline(time.Now().Add(timeout))
if _, err = conn.rw.Write(C0C1); err != nil {
return
}
conn.Conn.SetDeadline(time.Now().Add(timeout))
if err = conn.rw.Flush(); err != nil {
return
}
// < S0S1S2
conn.Conn.SetDeadline(time.Now().Add(timeout))
if _, err = io.ReadFull(conn.rw, S0S1S2); err != nil {
return
}
S1 := S0S1S2[1 : 1536+1]
if ver := utils.U32BE(S1[4:8]); ver != 0 {
C2 = S1
} else {
C2 = S1
}
// > C2
conn.Conn.SetDeadline(time.Now().Add(timeout))
if _, err = conn.rw.Write(C2); err != nil {
return
}
conn.Conn.SetDeadline(time.Time{})
return
}
// HandshakeServer todo comment
func (conn *RtmpConn) HandshakeServer() (err error) {
var random [(1 + 1536*2) * 2]byte
C0C1C2 := random[:1536*2+1]
C0 := C0C1C2[:1]
C1 := C0C1C2[1 : 1536+1]
C0C1 := C0C1C2[:1536+1]
C2 := C0C1C2[1536+1:]
S0S1S2 := random[1536*2+1:]
S0 := S0S1S2[:1]
S1 := S0S1S2[1 : 1536+1]
S0S1 := S0S1S2[:1536+1]
S2 := S0S1S2[1536+1:]
// < C0C1
conn.Conn.SetDeadline(time.Now().Add(timeout))
if _, err = io.ReadFull(conn.rw, C0C1); err != nil {
return
}
conn.Conn.SetDeadline(time.Now().Add(timeout))
if C0[0] != 3 {
err = fmt.Errorf("rtmp: handshake version=%d invalid", C0[0])
return
}
S0[0] = 3
clitime := utils.U32BE(C1[0:4])
srvtime := clitime
srvver := uint32(0x0d0e0a0d)
cliver := utils.U32BE(C1[4:8])
if cliver != 0 {
var ok bool
var digest []byte
if ok, digest = hsParse1(C1, hsClientPartialKey, hsServerFullKey); !ok {
err = fmt.Errorf("rtmp: handshake server: C1 invalid")
return
}
hsCreate01(S0S1, srvtime, srvver, hsServerPartialKey)
hsCreate2(S2, digest)
} else {
copy(S1, C2)
copy(S2, C1)
}
// > S0S1S2
conn.Conn.SetDeadline(time.Now().Add(timeout))
if _, err = conn.rw.Write(S0S1S2); err != nil {
return
}
conn.Conn.SetDeadline(time.Now().Add(timeout))
if err = conn.rw.Flush(); err != nil {
return
}
// < C2
conn.Conn.SetDeadline(time.Now().Add(timeout))
if _, err = io.ReadFull(conn.rw, C2); err != nil {
return
}
conn.Conn.SetDeadline(time.Time{})
return
}
+126
View File
@@ -0,0 +1,126 @@
package core
import (
"bufio"
"io"
)
// ReadWriter ...
type ReadWriter struct {
*bufio.ReadWriter
readError error
writeError error
}
// NewReadWriter ...
func NewReadWriter(rw io.ReadWriter, bufSize int) *ReadWriter {
return &ReadWriter{
ReadWriter: bufio.NewReadWriter(bufio.NewReaderSize(rw, bufSize),
bufio.NewWriterSize(rw, bufSize)),
}
}
// Read ...
func (rw *ReadWriter) Read(p []byte) (int, error) {
if rw.readError != nil {
return 0, rw.readError
}
n, err := io.ReadAtLeast(rw.ReadWriter, p, len(p))
rw.readError = err
return n, err
}
// ReadError ...
func (rw *ReadWriter) ReadError() error {
return rw.readError
}
// ReadUintBE ...
func (rw *ReadWriter) ReadUintBE(n int) (uint32, error) {
if rw.readError != nil {
return 0, rw.readError
}
ret := uint32(0)
for i := 0; i < n; i++ {
b, err := rw.ReadByte()
if err != nil {
rw.readError = err
return 0, err
}
ret = ret<<8 + uint32(b)
}
return ret, nil
}
// ReadUintLE ...
func (rw *ReadWriter) ReadUintLE(n int) (uint32, error) {
if rw.readError != nil {
return 0, rw.readError
}
ret := uint32(0)
for i := 0; i < n; i++ {
b, err := rw.ReadByte()
if err != nil {
rw.readError = err
return 0, err
}
ret += uint32(b) << uint32(i*8)
}
return ret, nil
}
// Flush ...
func (rw *ReadWriter) Flush() error {
if rw.writeError != nil {
return rw.writeError
}
if rw.ReadWriter.Writer.Buffered() == 0 {
return nil
}
return rw.ReadWriter.Flush()
}
// Write ...
func (rw *ReadWriter) Write(p []byte) (int, error) {
if rw.writeError != nil {
return 0, rw.writeError
}
return rw.ReadWriter.Write(p)
}
// WriteError ...
func (rw *ReadWriter) WriteError() error {
return rw.writeError
}
// WriteUintBE ...
func (rw *ReadWriter) WriteUintBE(v uint32, n int) error {
if rw.writeError != nil {
return rw.writeError
}
for i := 0; i < n; i++ {
b := byte(v>>uint32((n-i-1)<<3)) & 0xff
if err := rw.WriteByte(b); err != nil {
rw.writeError = err
return err
}
}
return nil
}
// WriteUintLE ...
func (rw *ReadWriter) WriteUintLE(v uint32, n int) error {
if rw.writeError != nil {
return rw.writeError
}
for i := 0; i < n; i++ {
b := byte(v) & 0xff
if err := rw.WriteByte(b); err != nil {
rw.writeError = err
return err
}
v = v >> 8
}
return nil
}
+136
View File
@@ -0,0 +1,136 @@
package core
import (
"bytes"
"io"
"testing"
"github.com/stretchr/testify/assert"
)
func TestReader(t *testing.T) {
at := assert.New(t)
buf := bytes.NewBufferString("abc")
r := NewReadWriter(buf, 1024)
b := make([]byte, 3)
n, err := r.Read(b)
at.Equal(err, nil)
at.Equal(r.ReadError(), nil)
at.Equal(n, 3)
n, err = r.Read(b)
at.Equal(err, io.EOF)
at.Equal(r.ReadError(), io.EOF)
buf.WriteString("123")
n, err = r.Read(b)
at.Equal(err, io.EOF)
at.Equal(r.ReadError(), io.EOF)
at.Equal(n, 0)
}
func TestReaderUintBE(t *testing.T) {
at := assert.New(t)
type Test struct {
i int
value uint32
bytes []byte
}
tests := []Test{
{1, 0x01, []byte{0x01}},
{2, 0x0102, []byte{0x01, 0x02}},
{3, 0x010203, []byte{0x01, 0x02, 0x03}},
{4, 0x01020304, []byte{0x01, 0x02, 0x03, 0x04}},
}
for _, test := range tests {
buf := bytes.NewBuffer(test.bytes)
r := NewReadWriter(buf, 1024)
n, err := r.ReadUintBE(test.i)
at.Equal(err, nil, "test %d", test.i)
at.Equal(n, test.value, "test %d", test.i)
}
}
func TestReaderUintLE(t *testing.T) {
at := assert.New(t)
type Test struct {
i int
value uint32
bytes []byte
}
tests := []Test{
{1, 0x01, []byte{0x01}},
{2, 0x0102, []byte{0x02, 0x01}},
{3, 0x010203, []byte{0x03, 0x02, 0x01}},
{4, 0x01020304, []byte{0x04, 0x03, 0x02, 0x01}},
}
for _, test := range tests {
buf := bytes.NewBuffer(test.bytes)
r := NewReadWriter(buf, 1024)
n, err := r.ReadUintLE(test.i)
at.Equal(err, nil, "test %d", test.i)
at.Equal(n, test.value, "test %d", test.i)
}
}
func TestWriter(t *testing.T) {
at := assert.New(t)
buf := bytes.NewBuffer(nil)
w := NewReadWriter(buf, 1024)
b := []byte{1, 2, 3}
n, err := w.Write(b)
at.Equal(err, nil)
at.Equal(w.WriteError(), nil)
at.Equal(n, 3)
w.writeError = io.EOF
n, err = w.Write(b)
at.Equal(err, io.EOF)
at.Equal(w.WriteError(), io.EOF)
at.Equal(n, 0)
}
func TestWriteUintBE(t *testing.T) {
at := assert.New(t)
type Test struct {
i int
value uint32
bytes []byte
}
tests := []Test{
{1, 0x01, []byte{0x01}},
{2, 0x0102, []byte{0x01, 0x02}},
{3, 0x010203, []byte{0x01, 0x02, 0x03}},
{4, 0x01020304, []byte{0x01, 0x02, 0x03, 0x04}},
}
for _, test := range tests {
buf := bytes.NewBuffer(nil)
r := NewReadWriter(buf, 1024)
err := r.WriteUintBE(test.value, test.i)
at.Equal(err, nil, "test %d", test.i)
err = r.Flush()
at.Equal(err, nil, "test %d", test.i)
at.Equal(buf.Bytes(), test.bytes, "test %d", test.i)
}
}
func TestWriteUintLE(t *testing.T) {
at := assert.New(t)
type Test struct {
i int
value uint32
bytes []byte
}
tests := []Test{
{1, 0x01, []byte{0x01}},
{2, 0x0102, []byte{0x02, 0x01}},
{3, 0x010203, []byte{0x03, 0x02, 0x01}},
{4, 0x01020304, []byte{0x04, 0x03, 0x02, 0x01}},
}
for _, test := range tests {
buf := bytes.NewBuffer(nil)
r := NewReadWriter(buf, 1024)
err := r.WriteUintLE(test.value, test.i)
at.Equal(err, nil, "test %d", test.i)
err = r.Flush()
at.Equal(err, nil, "test %d", test.i)
at.Equal(buf.Bytes(), test.bytes, "test %d", test.i)
}
}
+250
View File
@@ -0,0 +1,250 @@
package core
import (
"encoding/binary"
"net"
"time"
"github.com/H0RlZ0N/gortmppush/utils"
)
const (
_ = iota
idSetChunkSize
idAbortMessage
idAck
idUserControlMessages
idWindowAckSize
idSetPeerBandwidth
)
// RtmpConn ...
type RtmpConn struct {
net.Conn
chunkSize uint32
remoteChunkSize uint32
windowAckSize uint32
remoteWindowAckSize uint32
received uint32
ackReceived uint32
rw *ReadWriter
pool *utils.Pool
chunks map[uint32]*ChunkStream
}
// NewRtmpConn ...
func NewRtmpConn(c net.Conn, bufferSize int) *RtmpConn {
return &RtmpConn{
Conn: c,
chunkSize: 128,
remoteChunkSize: 128,
windowAckSize: 2500000,
remoteWindowAckSize: 2500000,
pool: utils.NewPool(),
rw: NewReadWriter(c, bufferSize),
chunks: make(map[uint32]*ChunkStream),
}
}
func (rtmpConn *RtmpConn) Read() (c *ChunkStream, err error) {
var rb byte
for {
//读取第一个字节
if rb, err = rtmpConn.rw.ReadByte(); err != nil {
return nil, err
}
format := uint32(rb >> 6) //获取fmt,前面两位是fmt
csid := uint32(rb & 0x3f) //获取csid,先获取后面6位的csid
switch csid {
case 0: //csid有2个字节,需要再读取一个字节
if rb, err = rtmpConn.rw.ReadByte(); err != nil {
return nil, err
}
csid = uint32(rb) + 64 //从64开始计算
case 1:
if csid, err = rtmpConn.rw.ReadUintLE(2); err != nil {
return nil, err
}
csid += 64 //从64开始计算
case 2: //表示该chunk是控制信息和命令信息,相当于控制消息的csid就是2
default: //该6位就是一个csid 不用处理
}
cs, ok := rtmpConn.chunks[csid]
if !ok { //如果没找到,就创建一个新的chunkstream
cs = &ChunkStream{
CSID: csid,
}
rtmpConn.chunks[csid] = cs
}
cs.tmpFromat = format
if err = cs.readChunk(rtmpConn.rw, rtmpConn.remoteChunkSize); err != nil {
return nil, err
}
//判断当前chunk是否读取完成
if cs.isComplete() {
c = &ChunkStream{
Format: cs.Format,
CSID: cs.CSID,
Timestamp: cs.Timestamp,
Length: cs.Length,
TypeID: cs.TypeID,
StreamID: cs.StreamID,
Data: cs.Data[0:cs.Length],
}
//如果是控制消息,就直接处理掉,不反回到外层
isHandled := rtmpConn.handleControlMsg(cs)
rtmpConn.ack(cs.Length)
if !isHandled {
return
}
}
}
}
func (rtmpConn *RtmpConn) Write(c *ChunkStream) error {
if c.TypeID == idSetChunkSize {
rtmpConn.chunkSize = binary.BigEndian.Uint32(c.Data)
}
return c.writeChunk(rtmpConn.rw, int(rtmpConn.chunkSize))
}
// Flush ...
func (rtmpConn *RtmpConn) Flush() error {
return rtmpConn.rw.Flush()
}
// Close ...
func (rtmpConn *RtmpConn) Close() error {
return rtmpConn.Conn.Close()
}
// RemoteAddr ...
func (rtmpConn *RtmpConn) RemoteAddr() net.Addr {
return rtmpConn.Conn.RemoteAddr()
}
// LocalAddr ...
func (rtmpConn *RtmpConn) LocalAddr() net.Addr {
return rtmpConn.Conn.LocalAddr()
}
// SetDeadline ...
func (rtmpConn *RtmpConn) SetDeadline(t time.Time) error {
return rtmpConn.Conn.SetDeadline(t)
}
// NewAck ...
func (rtmpConn *RtmpConn) NewAck(size uint32) ChunkStream {
return initControlMsg(idAck, 4, size)
}
// NewSetChunkSize ...
func (rtmpConn *RtmpConn) NewSetChunkSize(size uint32) ChunkStream {
return initControlMsg(idSetChunkSize, 4, size)
}
// NewWindowAckSize ...
func (rtmpConn *RtmpConn) NewWindowAckSize(size uint32) ChunkStream {
return initControlMsg(idWindowAckSize, 4, size)
}
// NewSetPeerBandwidth ...
func (rtmpConn *RtmpConn) NewSetPeerBandwidth(size uint32) ChunkStream {
ret := initControlMsg(idSetPeerBandwidth, 5, size)
ret.Data[4] = 2
return ret
}
// handleControlMsg 处理协议层消息
func (rtmpConn *RtmpConn) handleControlMsg(c *ChunkStream) bool {
switch c.TypeID {
case idSetChunkSize:
rtmpConn.remoteChunkSize = binary.BigEndian.Uint32(c.Data)
case idAbortMessage:
case idAck:
case idUserControlMessages:
case idWindowAckSize:
rtmpConn.remoteWindowAckSize = binary.BigEndian.Uint32(c.Data)
case idSetPeerBandwidth:
default:
return false
}
return true
}
func (rtmpConn *RtmpConn) ack(size uint32) {
rtmpConn.received += uint32(size)
rtmpConn.ackReceived += uint32(size)
if rtmpConn.received >= 0xf0000000 {
rtmpConn.received = 0
}
if rtmpConn.ackReceived >= rtmpConn.remoteWindowAckSize {
cs := rtmpConn.NewAck(rtmpConn.ackReceived)
cs.writeChunk(rtmpConn.rw, int(rtmpConn.chunkSize))
rtmpConn.ackReceived = 0
}
}
func initControlMsg(id, size, value uint32) ChunkStream {
ret := ChunkStream{
Format: 0,
CSID: 2,
TypeID: id,
StreamID: 0,
Length: size,
Data: make([]byte, size),
}
utils.PutU32BE(ret.Data[:size], value)
return ret
}
const (
streamBegin uint32 = 0
streamEOF uint32 = 1
streamDry uint32 = 2
setBufferLen uint32 = 3
streamIsRecorded uint32 = 4
pingRequest uint32 = 6
pingResponse uint32 = 7
)
/*
+------------------------------+-------------------------
| Event Type ( 2- bytes ) | Event Data
+------------------------------+-------------------------
Pay load for the User Control Message.
*/
func (rtmpConn *RtmpConn) userControlMsg(eventType, buflen uint32) ChunkStream {
var ret ChunkStream
buflen += 2
ret = ChunkStream{
Format: 0,
CSID: 2,
TypeID: 4,
StreamID: 1,
Length: buflen,
Data: make([]byte, buflen),
}
ret.Data[0] = byte(eventType >> 8 & 0xff)
ret.Data[1] = byte(eventType & 0xff)
return ret
}
// SetBegin ...
func (rtmpConn *RtmpConn) SetBegin() {
ret := rtmpConn.userControlMsg(streamBegin, 4)
for i := 0; i < 4; i++ {
ret.Data[2+i] = byte(1 >> uint32((3-i)*8) & 0xff)
}
rtmpConn.Write(&ret)
}
// SetRecorded ...
func (rtmpConn *RtmpConn) SetRecorded() {
ret := rtmpConn.userControlMsg(streamIsRecorded, 4)
for i := 0; i < 4; i++ {
ret.Data[2+i] = byte(1 >> uint32((3-i)*8) & 0xff)
}
rtmpConn.Write(&ret)
}
+124
View File
@@ -0,0 +1,124 @@
package protocol
import (
"bytes"
"fmt"
"io"
"github.com/H0RlZ0N/gortmppush/logger"
"github.com/H0RlZ0N/gortmppush/protocol/amf"
"github.com/H0RlZ0N/gortmppush/protocol/core"
)
var (
STOP_CTRL = "RTMPRELAY_STOP"
)
type RtmpRelay struct {
PlayUrl string
PublishUrl string
cs_chan chan *core.ChunkStream
sndctrl_chan chan string
connectPlayClient *core.ConnClient
connectPublishClient *core.ConnClient
startflag bool
logger logger.Logger
}
func NewRtmpRelay(playurl *string, publishurl *string) *RtmpRelay {
return &RtmpRelay{
PlayUrl: *playurl,
PublishUrl: *publishurl,
cs_chan: make(chan *core.ChunkStream, 500),
sndctrl_chan: make(chan string),
connectPlayClient: nil,
connectPublishClient: nil,
startflag: false,
}
}
func (self *RtmpRelay) rcvPlayChunkStream() {
fmt.Printf("rcvPlayRtmpMediaPacket connectClient.Read...\n")
for {
if self.startflag == false {
self.connectPlayClient.Close()
fmt.Printf("rcvPlayChunkStream close: playurl=%s, publishurl=%s\n", self.PlayUrl, self.PublishUrl)
break
}
rc, err := self.connectPlayClient.Read()
if err != nil && err == io.EOF {
break
}
//glog.Infof("connectPlayClient.Read return rc.TypeID=%v length=%d, err=%v", rc.TypeID, len(rc.Data), err)
switch rc.TypeID {
case 20, 17:
r := bytes.NewReader(rc.Data)
vs, err := self.connectPlayClient.DecodeBatch(r, amf.AMF0)
fmt.Printf("rcvPlayRtmpMediaPacket: vs=%v, err=%v\n", vs, err)
case 18:
fmt.Printf("rcvPlayRtmpMediaPacket: metadata....\n")
case 8, 9:
self.cs_chan <- rc
}
}
}
func (self *RtmpRelay) sendPublishChunkStream() {
for {
select {
case rc := <-self.cs_chan:
//glog.Infof("sendPublishChunkStream: rc.TypeID=%v length=%d", rc.TypeID, len(rc.Data))
self.connectPublishClient.Write(rc)
case ctrlcmd := <-self.sndctrl_chan:
if ctrlcmd == STOP_CTRL {
self.connectPublishClient.Close()
fmt.Printf("sendPublishChunkStream close: playurl=%s, publishurl=%s\n", self.PlayUrl, self.PublishUrl)
break
}
}
}
}
// Start ...
func (self *RtmpRelay) Start() error {
if self.startflag {
return fmt.Errorf("The rtmprelay already started, playurl=%s, publishurl=%s", self.PlayUrl, self.PublishUrl)
}
self.connectPlayClient = core.NewConnClient(self.logger)
self.connectPublishClient = core.NewConnClient(self.logger)
self.logger.Debugf("Play server addr:%s starting....", self.PlayUrl)
err := self.connectPlayClient.Start(self.PlayUrl, "play")
if err != nil {
fmt.Printf("connectPlayClient.Start url=%v error\n", self.PlayUrl)
return err
}
fmt.Printf("publish server addr:%v starting....\n", self.PublishUrl)
err = self.connectPublishClient.Start(self.PublishUrl, "publish")
if err != nil {
fmt.Printf("connectPublishClient.Start url=%v error\n", self.PublishUrl)
self.connectPlayClient.Close()
return err
}
self.startflag = true
go self.rcvPlayChunkStream()
go self.sendPublishChunkStream()
return nil
}
func (self *RtmpRelay) Stop() {
if !self.startflag {
fmt.Printf("The rtmprelay already stoped, playurl=%s, publishurl=%s\n", self.PlayUrl, self.PublishUrl)
return
}
self.startflag = false
self.sndctrl_chan <- STOP_CTRL
}
+181
View File
@@ -0,0 +1,181 @@
package protocol
import (
"errors"
"fmt"
"sync"
"github.com/H0RlZ0N/gortmppush/av"
"github.com/H0RlZ0N/gortmppush/configure"
"github.com/H0RlZ0N/gortmppush/logger"
"github.com/H0RlZ0N/gortmppush/protocol/core"
)
type StaticPush struct {
RtmpUrl string
packet_chan chan *av.Packet
sndctrl_chan chan string
connectClient *core.ConnClient
startflag bool
logger logger.Logger
}
var G_StaticPushMap = make(map[string](*StaticPush))
var g_MapLock = new(sync.RWMutex)
var (
STATIC_RELAY_STOP_CTRL = "STATIC_RTMPRELAY_STOP"
)
func GetStaticPushList(appname string) ([]string, error) {
pushurlList, ok := configure.GetStaticPushUrlList(appname)
if !ok {
return nil, errors.New("no static push url")
}
return pushurlList, nil
}
func GetAndCreateStaticPushObject(rtmpurl string) *StaticPush {
g_MapLock.RLock()
staticpush, ok := G_StaticPushMap[rtmpurl]
fmt.Printf("GetAndCreateStaticPushObject: %s, return %v\n", rtmpurl, ok)
if !ok {
g_MapLock.RUnlock()
newStaticpush := NewStaticPush(rtmpurl)
g_MapLock.Lock()
G_StaticPushMap[rtmpurl] = newStaticpush
g_MapLock.Unlock()
return newStaticpush
}
g_MapLock.RUnlock()
return staticpush
}
func GetStaticPushObject(rtmpurl string) (*StaticPush, error) {
g_MapLock.RLock()
if staticpush, ok := G_StaticPushMap[rtmpurl]; ok {
g_MapLock.RUnlock()
return staticpush, nil
}
g_MapLock.RUnlock()
return nil, errors.New(fmt.Sprintf("G_StaticPushMap[%s] not exist...."))
}
func ReleaseStaticPushObject(rtmpurl string) {
g_MapLock.RLock()
if _, ok := G_StaticPushMap[rtmpurl]; ok {
g_MapLock.RUnlock()
fmt.Printf("ReleaseStaticPushObject %s ok\n", rtmpurl)
g_MapLock.Lock()
delete(G_StaticPushMap, rtmpurl)
g_MapLock.Unlock()
} else {
g_MapLock.RUnlock()
fmt.Printf("ReleaseStaticPushObject: not find %s\n", rtmpurl)
}
}
func NewStaticPush(rtmpurl string) *StaticPush {
return &StaticPush{
RtmpUrl: rtmpurl,
packet_chan: make(chan *av.Packet, 500),
sndctrl_chan: make(chan string),
connectClient: nil,
startflag: false,
}
}
func (self *StaticPush) Start() error {
if self.startflag {
return errors.New(fmt.Sprintf("StaticPush already start %s", self.RtmpUrl))
}
self.connectClient = core.NewConnClient(self.logger)
fmt.Printf("static publish server addr:%v starting....\n", self.RtmpUrl)
err := self.connectClient.Start(self.RtmpUrl, "publish")
if err != nil {
fmt.Printf("connectClient.Start url=%v error\n", self.RtmpUrl)
return err
}
fmt.Printf("static publish server addr:%v started, streamid=%d\n", self.RtmpUrl, self.connectClient.GetStreamID())
go self.HandleAvPacket()
self.startflag = true
return nil
}
func (self *StaticPush) Stop() {
if !self.startflag {
return
}
fmt.Printf("StaticPush Stop: %s\n", self.RtmpUrl)
self.sndctrl_chan <- STATIC_RELAY_STOP_CTRL
self.startflag = false
}
func (self *StaticPush) WriteAvPacket(packet *av.Packet) {
if !self.startflag {
return
}
self.packet_chan <- packet
}
func (self *StaticPush) sendPacket(p *av.Packet) {
if !self.startflag {
return
}
var cs core.ChunkStream
cs.Data = p.Data
cs.Length = uint32(len(p.Data))
cs.StreamID = self.connectClient.GetStreamID()
cs.Timestamp = p.TimeStamp
//cs.Timestamp += v.BaseTimeStamp()
//glog.Infof("Static sendPacket: rtmpurl=%s, length=%d, streamid=%d",
// self.RtmpUrl, len(p.Data), cs.StreamID)
switch p.PacketType {
case av.PacketTypeVideo:
cs.TypeID = av.TAG_VIDEO
case av.PacketTypeAudio:
cs.TypeID = av.TAG_AUDIO
case av.PacketTypeMetadata:
cs.TypeID = av.TAG_SCRIPTDATAAMF0
default:
}
self.connectClient.Write(&cs)
}
func (self *StaticPush) HandleAvPacket() {
if !self.IsStart() {
fmt.Printf("static push %s not started\n", self.RtmpUrl)
return
}
for {
select {
case packet := <-self.packet_chan:
self.sendPacket(packet)
case ctrlcmd := <-self.sndctrl_chan:
if ctrlcmd == STATIC_RELAY_STOP_CTRL {
self.connectClient.Close()
fmt.Printf("Static HandleAvPacket close: publishurl=%s\n", self.RtmpUrl)
break
}
}
}
}
func (self *StaticPush) IsStart() bool {
return self.startflag
}
+243
View File
@@ -0,0 +1,243 @@
package protocol
import (
"fmt"
"sync"
"time"
"github.com/H0RlZ0N/gortmppush/av"
"github.com/H0RlZ0N/gortmppush/logger"
"github.com/H0RlZ0N/gortmppush/protocol/cache"
"github.com/H0RlZ0N/gortmppush/utils"
)
// StreamInfo ...
type StreamInfo struct {
App string
Name string
URL string
}
// RtmpStream rtmp流类型
type RtmpStream struct {
streamID string
isStart bool
cache *cache.Cache
reader ReadCloser
writers []WriteCloser
streamInfo StreamInfo
pktChan chan *av.Packet
writerChan chan WriteCloser
readerChan chan ReadCloser
streamHandler *StreamHandler
logger logger.Logger
}
// PackWriterCloser packet写对象结构
type PackWriterCloser struct {
init bool
w WriteCloser
}
// NewWriter 创建新的写对象
func (p *PackWriterCloser) NewWriter() (WriteCloser, error) {
return p.w, nil
}
// NewStream 创建新的rtmp流
func NewStream(streamInfo StreamInfo, handler *StreamHandler, log logger.Logger) *RtmpStream {
return &RtmpStream{
streamID: utils.NewId(),
streamInfo: streamInfo,
cache: cache.NewCache(),
streamHandler: handler,
writers: make([]WriteCloser, 0),
writerChan: make(chan WriteCloser, 1),
readerChan: make(chan ReadCloser, 1),
pktChan: make(chan *av.Packet, 16),
logger: log,
}
}
// ID 获取rtmp流id
func (s *RtmpStream) ID() string {
if s.reader != nil {
return s.streamID
}
return ""
}
// GetReader 获取rtmp流读对象
func (s *RtmpStream) GetReader() ReadCloser {
return s.reader
}
// AddReader 为rtmp流对象添加一个读对象
func (s *RtmpStream) AddReader(r ReadCloser) error {
go func() {
s.readerChan <- r
}()
return nil
}
// AddWriter 为rtmp流对象添加一个写对象
func (s *RtmpStream) AddWriter(w WriteCloser) error {
go func() {
s.writerChan <- w
}()
return nil
}
// 开始读取流数据
func (s *RtmpStream) startRead(wg *sync.WaitGroup) {
s.logger.Infof("Start to read data, id:%s", s.streamID)
wg.Add(1)
defer wg.Done()
for {
pkt := &av.Packet{}
if err := s.reader.Read(pkt); err != nil {
s.logger.Errorf("Read pkt failed, %s", err.Error())
return
}
//先缓存数据包
s.cache.Write(pkt)
select {
case s.pktChan <- pkt:
default:
}
}
}
// 转发流数据
func (s *RtmpStream) streamLoop() {
s.logger.Infof("Start stream loop, %s", s.streamID)
checkTicker := time.NewTicker(time.Second * 30)
defer func() {
streamKey := fmt.Sprintf("%s_%s", s.streamInfo.App, s.streamInfo.Name)
s.streamHandler.remove(streamKey)
s.close()
checkTicker.Stop()
s.logger.Infof("Rtmp stream[%s] exit.", s.streamID)
}()
var wg sync.WaitGroup
lastWriteRemove := time.Now()
for {
select {
case pkt := <-s.pktChan:
{
bRemove := false
for i, w := range s.writers {
if err := w.Write(pkt); err != nil {
s.logger.Infof("Write packet failed, %s close writer.", err.Error())
w.Close() //todo 是否要传递参数
s.writers[i] = nil
bRemove = true
}
}
if bRemove {
for i := 0; i < len(s.writers); {
if s.writers[i] == nil {
s.writers = append(s.writers[:i], s.writers[i+1:]...)
} else {
i++
}
}
lastWriteRemove = time.Now()
}
}
case w := <-s.writerChan: // 接收到play消息
{
//TODO 这个方法不是很好,先这样,后续再优化
sw, ok := w.(*StreamWriter)
if ok == false {
s.logger.Errorf("can not cast writerclose to streamwriter")
w.Close()
return
}
if err := s.cache.Send(sw.packetQueue); err != nil {
s.logger.Errorf("Send cache failed, %s", err.Error())
w.Close()
return
}
s.writers = append(s.writers, w)
}
case r := <-s.readerChan: // 接收到push消息
{
if s.reader != nil {
s.reader.Close()
wg.Wait() //等待读取数据协程结束
//清除pktChan中的
CleanLoop:
for {
select {
case <-s.pktChan:
default:
break CleanLoop
}
}
//更新一下基本时间戳,保证每个writer的时间戳都是递增的
for _, w := range s.writers {
w.CalcBaseTimestamp()
}
}
s.reader = r
go s.startRead(&wg)
}
case <-checkTicker.C:
{
//检查是否有writer,没有则释放
if len(s.writers) == 0 && time.Now().Sub(lastWriteRemove) >= 30 {
s.logger.Debug("Stream no play...")
//return
}
//检查是否有reader
if s.reader == nil || !s.reader.Alive() {
s.logger.Debugf("Stream reader is nil(%v) or not alive, exit", s.reader == nil)
return
}
//检查每个writer是否超时
for i := 0; i < len(s.writers); {
w := s.writers[i]
if !w.Alive() {
s.writers = append(s.writers[:i], s.writers[i+1:]...)
w.Close() //todo 是否要传递关闭原因
lastWriteRemove = time.Now()
} else {
i++
}
}
}
}
}
}
func (s *RtmpStream) close() {
if s.reader != nil {
s.reader.Close()
s.logger.Infof("[%s] publish closed.", s.streamID)
}
//可能writerChan或readerChan中有未处理的writer和reader
//读取出来,并关闭
CloseLoop:
for {
select {
case w := <-s.writerChan:
w.Close()
case r := <-s.readerChan:
r.Close()
default:
break CloseLoop
}
}
for _, writer := range s.writers {
writer.Close()
}
}
+86
View File
@@ -0,0 +1,86 @@
package protocol
import (
"fmt"
"sync"
"github.com/H0RlZ0N/gortmppush/logger"
"github.com/H0RlZ0N/gortmppush/protocol/core"
)
// StreamHandler 管理RtmpStream,每个RtmpStream代表一路流
type StreamHandler struct {
mutex sync.Mutex
logger logger.Logger
streams map[string]*RtmpStream
}
// NewStreamHandler 创建一个管理RtmpStream的Handler
func NewStreamHandler(log logger.Logger) *StreamHandler {
handler := &StreamHandler{
logger: log,
streams: make(map[string]*RtmpStream),
}
return handler
}
// get rtmp stream, if not exist, create a new one
// bool indicate weathe the stream is new, true-new false-not
func (h *StreamHandler) getOrCreate(streamInfo StreamInfo) *RtmpStream {
h.mutex.Lock()
defer h.mutex.Unlock()
streamKey := fmt.Sprintf("%s_%s", streamInfo.App, streamInfo.Name)
if stream, ok := h.streams[streamKey]; ok {
return stream
}
stream := NewStream(streamInfo, h, h.logger)
h.streams[streamKey] = stream
go stream.streamLoop()
h.logger.Infof("Create new stream, id:%s app:%s name:%s", stream.streamID,
streamInfo.App, streamInfo.Name)
return stream
}
func (h *StreamHandler) remove(key string) {
h.mutex.Lock()
defer h.mutex.Unlock()
if _, ok := h.streams[key]; ok {
delete(h.streams, key)
}
}
// GetStreams 获取所有的流
func (h *StreamHandler) GetStreams() []*RtmpStream {
streams := make([]*RtmpStream, 0)
h.mutex.Lock()
defer h.mutex.Unlock()
for _, v := range h.streams {
streams = append(streams, v)
}
return streams
}
// HandleConnect ...
func (h *StreamHandler) HandleConnect(conn *core.ForwardConnect) error {
app, name, url := conn.GetStreamInfo()
streamInfo := StreamInfo{
App: app,
Name: name,
URL: url,
}
stream := h.getOrCreate(streamInfo)
if conn.IsPublisher() {
reader := NewStreamReader(conn, stream.ID(), h.logger)
if err := stream.AddReader(reader); err != nil {
return fmt.Errorf("Add stream reader failed, %v", err)
}
} else {
writer := NewStreamWriter(conn, stream.ID(), h.logger)
if err := stream.AddWriter(writer); err != nil {
return fmt.Errorf("Add stream writer failed, %v", err)
}
}
return nil
}
+316
View File
@@ -0,0 +1,316 @@
package protocol
import (
"errors"
"fmt"
"time"
"github.com/H0RlZ0N/gortmppush/av"
"github.com/H0RlZ0N/gortmppush/container/flv"
"github.com/H0RlZ0N/gortmppush/logger"
"github.com/H0RlZ0N/gortmppush/protocol/core"
)
const (
maxQueueNum = 1024
saveStaticsInterval = 5000
)
// ReadCloser ...
type ReadCloser interface {
Close()
Alive() bool
Read(*av.Packet) error
}
// WriteCloser ...
type WriteCloser interface {
Close()
Alive() bool
CalcBaseTimestamp()
Write(*av.Packet) error
}
// StaticsBW todo comment
type StaticsBW struct {
StreamID uint32
VideoDatainBytes uint64
LastVideoDatainBytes uint64
VideoSpeedInBytesperMS uint64
AudioDatainBytes uint64
LastAudioDatainBytes uint64
AudioSpeedInBytesperMS uint64
LastTimestamp int64
}
// StreamWriter 是代表rtmp连接的写入对象
type StreamWriter struct {
av.RWBaser
streamID string
closed bool
keyframeNeed bool
conn *core.ForwardConnect
packetQueue chan *av.Packet
WriteBWInfo StaticsBW
logger logger.Logger
}
// NewStreamWriter 创建一个新的写入对象
func NewStreamWriter(conn *core.ForwardConnect, streamID string, log logger.Logger) *StreamWriter {
writer := &StreamWriter{
streamID: streamID,
conn: conn,
RWBaser: av.NewRWBaser(time.Second * 10),
packetQueue: make(chan *av.Packet, maxQueueNum),
WriteBWInfo: StaticsBW{0, 0, 0, 0, 0, 0, 0, 0},
logger: log,
keyframeNeed: true,
}
//todo 这个是否有必要先检查一下读写情况
go writer.Check()
go func() {
err := writer.SendPacket()
if err != nil {
writer.logger.Errorf("SendPacket failed, %s", err.Error())
}
}()
return writer
}
// SaveStatics 保存统计信息
func (sw *StreamWriter) SaveStatics(streamid uint32, length uint64, isVideoFlag bool) {
nowInMS := int64(time.Now().UnixNano() / 1e6)
sw.WriteBWInfo.StreamID = streamid
if isVideoFlag {
sw.WriteBWInfo.VideoDatainBytes = sw.WriteBWInfo.VideoDatainBytes + length
} else {
sw.WriteBWInfo.AudioDatainBytes = sw.WriteBWInfo.AudioDatainBytes + length
}
if sw.WriteBWInfo.LastTimestamp == 0 {
sw.WriteBWInfo.LastTimestamp = nowInMS
} else if (nowInMS - sw.WriteBWInfo.LastTimestamp) >= saveStaticsInterval {
diffTimestamp := (nowInMS - sw.WriteBWInfo.LastTimestamp) / 1000
sw.WriteBWInfo.VideoSpeedInBytesperMS = (sw.WriteBWInfo.VideoDatainBytes - sw.WriteBWInfo.LastVideoDatainBytes) * 8 / uint64(diffTimestamp) / 1000
sw.WriteBWInfo.AudioSpeedInBytesperMS = (sw.WriteBWInfo.AudioDatainBytes - sw.WriteBWInfo.LastAudioDatainBytes) * 8 / uint64(diffTimestamp) / 1000
sw.WriteBWInfo.LastVideoDatainBytes = sw.WriteBWInfo.VideoDatainBytes
sw.WriteBWInfo.LastAudioDatainBytes = sw.WriteBWInfo.AudioDatainBytes
sw.WriteBWInfo.LastTimestamp = nowInMS
}
}
// Check 连接状态检测
func (sw *StreamWriter) Check() {
for {
_, err := sw.conn.Read()
if err != nil {
sw.Close()
return
}
}
}
// Write ...
func (sw *StreamWriter) Write(p *av.Packet) (err error) {
if sw.closed {
err = errors.New("PeerWriter closed")
return
}
defer func() {
if e := recover(); e != nil {
err = fmt.Errorf("Panic %v", e)
}
}()
if p.PacketType == av.PacketTypeVideo {
if sw.keyframeNeed {
if p.VHeader.FrameType != av.FRAME_KEY {
sw.logger.Warn("Key frame need.")
return
}
sw.keyframeNeed = false
}
}
select {
case sw.packetQueue <- p:
default:
if p.PacketType == av.PacketTypeVideo && p.VHeader.FrameType == av.FRAME_KEY {
sw.keyframeNeed = true
}
sw.logger.Warn("packet droped...")
}
return
}
// SendPacket todo comment
func (sw *StreamWriter) SendPacket() error {
var cs core.ChunkStream
for {
p, ok := <-sw.packetQueue
if ok {
cs.Data = p.Data
cs.Length = uint32(len(p.Data))
cs.StreamID = p.StreamID
cs.Timestamp = p.TimeStamp
cs.Timestamp += sw.BaseTimeStamp()
isVideo := false
switch p.PacketType {
case av.PacketTypeVideo:
cs.TypeID = av.TAG_VIDEO
isVideo = true
case av.PacketTypeAudio:
cs.TypeID = av.TAG_AUDIO
case av.PacketTypeMetadata:
cs.TypeID = av.TAG_SCRIPTDATAAMF0
}
sw.SaveStatics(p.StreamID, uint64(cs.Length), isVideo)
sw.SetPreTime()
sw.RecTimeStamp(cs.Timestamp, cs.TypeID)
err := sw.conn.Write(cs)
if err != nil {
sw.closed = true
return err
}
sw.conn.Flush()
} else {
return errors.New("closed")
}
}
}
//StreamInfo todo comment
// func (sw *StreamWriter) StreamInfo() (ret av.StreamInfo) {
// ret.UID = sw.UID
// _, _, URL := sw.conn.GetStreamInfo()
// ret.URL = URL
// _url, err := url.Parse(URL)
// if err != nil {
// fmt.Printf("Parse url failed, url:%s err:%v\n", URL, err)
// }
// ret.Key = strings.TrimLeft(_url.Path, "/")
// ret.Inter = true
// return
// }
// Close todo comment
func (sw *StreamWriter) Close() {
if !sw.closed {
close(sw.packetQueue)
}
sw.closed = true
sw.conn.Close()
}
// StreamReader todo comment
type StreamReader struct {
av.RWBaser
streamID string
demuxer *flv.Demuxer
conn *core.ForwardConnect
ReadBWInfo StaticsBW
logger logger.Logger
}
// NewStreamReader 创建一个rtmp连接读对象
func NewStreamReader(conn *core.ForwardConnect, streamID string, log logger.Logger) *StreamReader {
return &StreamReader{
streamID: streamID,
conn: conn,
RWBaser: av.NewRWBaser(time.Second * 10),
demuxer: flv.NewDemuxer(),
ReadBWInfo: StaticsBW{0, 0, 0, 0, 0, 0, 0, 0},
logger: log,
}
}
// SaveStatics todo comment
func (pr *StreamReader) SaveStatics(streamid uint32, length uint64, isVideoFlag bool) {
nowInMS := int64(time.Now().UnixNano() / 1e6)
pr.ReadBWInfo.StreamID = streamid
if isVideoFlag {
pr.ReadBWInfo.VideoDatainBytes = pr.ReadBWInfo.VideoDatainBytes + length
} else {
pr.ReadBWInfo.AudioDatainBytes = pr.ReadBWInfo.AudioDatainBytes + length
}
if pr.ReadBWInfo.LastTimestamp == 0 {
pr.ReadBWInfo.LastTimestamp = nowInMS
} else if (nowInMS - pr.ReadBWInfo.LastTimestamp) >= saveStaticsInterval {
diffTimestamp := (nowInMS - pr.ReadBWInfo.LastTimestamp) / 1000
//glog.Infof("now=%d, last=%d, diff=%d", nowInMS, v.ReadBWInfo.LastTimestamp, diffTimestamp)
pr.ReadBWInfo.VideoSpeedInBytesperMS = (pr.ReadBWInfo.VideoDatainBytes - pr.ReadBWInfo.LastVideoDatainBytes) * 8 / uint64(diffTimestamp) / 1000
pr.ReadBWInfo.AudioSpeedInBytesperMS = (pr.ReadBWInfo.AudioDatainBytes - pr.ReadBWInfo.LastAudioDatainBytes) * 8 / uint64(diffTimestamp) / 1000
pr.ReadBWInfo.LastVideoDatainBytes = pr.ReadBWInfo.VideoDatainBytes
pr.ReadBWInfo.LastAudioDatainBytes = pr.ReadBWInfo.AudioDatainBytes
pr.ReadBWInfo.LastTimestamp = nowInMS
}
}
func (pr *StreamReader) Read(p *av.Packet) (err error) {
defer func() {
if r := recover(); r != nil {
fmt.Printf("rtmp read packet panic: %v\n", r)
}
}()
pr.SetPreTime()
var cs *core.ChunkStream
for {
if cs, err = pr.conn.Read(); err != nil {
return err
}
if cs.TypeID == av.TAG_AUDIO ||
cs.TypeID == av.TAG_VIDEO ||
cs.TypeID == av.TAG_SCRIPTDATAAMF0 ||
cs.TypeID == av.TAG_SCRIPTDATAAMF3 {
break
}
}
isVideo := false
switch cs.TypeID {
case av.TAG_VIDEO:
p.PacketType = av.PacketTypeVideo
isVideo = true
case av.TAG_AUDIO:
p.PacketType = av.PacketTypeAudio
case av.TAG_SCRIPTDATAAMF0, av.TAG_SCRIPTDATAAMF3:
p.PacketType = av.PacketTypeMetadata
}
p.StreamID = cs.StreamID
p.Data = cs.Data
p.TimeStamp = cs.Timestamp
pr.SaveStatics(p.StreamID, uint64(len(p.Data)), isVideo)
pr.demuxer.DemuxH(p)
return err
}
//StreamInfo 返回信息
// func (pr *StreamReader) StreamInfo() (ret av.StreamInfo) {
// ret.UID = pr.UID
// _, _, URL := pr.conn.GetStreamInfo()
// ret.URL = URL
// _url, err := url.Parse(URL)
// if err != nil {
// fmt.Printf("Parse url failed, url:%s err:%v\n", URL, err)
// }
// ret.Key = strings.TrimLeft(_url.Path, "/")
// return
// }
// Close 关闭读对象
func (pr *StreamReader) Close() {
pr.conn.Close()
}
+187
View File
@@ -0,0 +1,187 @@
package gortmppush
import (
"crypto/tls"
"fmt"
"net"
"time"
"github.com/H0RlZ0N/gortmppush/logger"
"github.com/H0RlZ0N/gortmppush/protocol"
"github.com/H0RlZ0N/gortmppush/protocol/core"
)
// Server rtmpfuwu
type Server struct {
handler *protocol.StreamHandler
logger logger.Logger
}
// NewRtmpServer 创建一个rtmp服务
func NewRtmpServer(h *protocol.StreamHandler, log logger.Logger) *Server {
return &Server{
handler: h,
logger: log,
}
}
// Serve 启动rtmp监听服务
func (s *Server) Serve(listenAddr string) (err error) {
defer func() {
if r := recover(); r != nil {
s.logger.Errorf("rtmp server panic:%v", r)
}
}()
var listener net.Listener
listener, err = net.Listen("tcp", listenAddr)
if err != nil {
err = fmt.Errorf("net.Listen failed, %v", err)
}
s.logger.Infof("Start rtmp server, listen on:%s", listenAddr)
for {
var netconn net.Conn
netconn, err = listener.Accept()
if err != nil {
if ne, ok := err.(net.Error); ok && ne.Temporary() {
//如果时临时错误,sleep一段时间继续
s.logger.Warn("Accept failed, temporary error, try again...")
time.Sleep(time.Millisecond * 100)
continue
}
s.logger.Errorf("Accept failed, err:%s", err.Error())
return
}
rtmpConn := core.NewRtmpConn(netconn, 4*1024)
s.logger.Infof("New rtmp connect, remote:%s local:%s",
rtmpConn.RemoteAddr().String(), rtmpConn.LocalAddr().String())
go s.handleConn(rtmpConn)
}
}
// ServeTLS 启动监听rtmp tls连接
func (s *Server) ServeTLS(listenAddr string, tlsCrt, tlsKey string) error {
defer func() {
if r := recover(); r != nil {
s.logger.Errorf("rtmps server panic:%v", r)
}
}()
cert, err := tls.LoadX509KeyPair(tlsCrt, tlsKey)
if err != nil {
return fmt.Errorf("tls.LoadX509KeyPair failed, %s", err.Error())
}
var listener net.Listener
config := &tls.Config{Certificates: []tls.Certificate{cert}}
listener, err = tls.Listen("tcp", listenAddr, config)
if err != nil {
return fmt.Errorf("Listen rtsp tls failed, %s", err.Error())
}
s.logger.Infof("Start rtmps server, listen on:%s", listenAddr)
for {
var netconn net.Conn
netconn, err = listener.Accept()
if err != nil {
if ne, ok := err.(net.Error); ok && ne.Temporary() {
//如果时临时错误,sleep一段时间继续
s.logger.Warn("Accept failed, temporary error, try again...")
time.Sleep(time.Millisecond * 100)
continue
}
return fmt.Errorf("Accept failed, %s", err.Error())
}
rtmpConn := core.NewRtmpConn(netconn, 4*1024)
s.logger.Infof("New rtmp connect, remote:%s local:%s",
rtmpConn.RemoteAddr().String(), rtmpConn.LocalAddr().String())
go s.handleConn(rtmpConn)
}
}
func (s *Server) handleConn(rtmpConn *core.RtmpConn) {
var err error
defer func() {
if err != nil {
rtmpConn.Close()
}
}()
if err = rtmpConn.HandshakeServer(); err != nil {
s.logger.Errorf("HandshakeServer failed, %s", err.Error())
return
}
//创建一个服务端连接
forwardConn := core.NewForwardConnect(rtmpConn, s.logger)
if err = forwardConn.SetUpPlayOrPublish(); err != nil {
s.logger.Errorf("SetUpPlayOrPublish failed, %s", err.Error())
return
}
//根据appname判断流是否存在
//如果是publish,如果对应的流已经存在,则关闭,重新创建
//如果是play,如果对应的流不存在,返回错误
if err = s.handler.HandleConnect(forwardConn); err != nil {
s.logger.Errorf("Handle connect failed, %v", err)
return
}
s.logger.Infof("Receive new connection, publisher:%v", forwardConn.IsPublisher())
}
// func (s *Server) handleConn1(rtmpConn *core.RtmpConn) {
// var err error
// defer func() {
// if err != nil {
// rtmpConn.Close()
// }
// }()
// if err = rtmpConn.HandshakeServer(); err != nil {
// fmt.Printf("HandshakeServer failed, %v\n", err)
// return
// }
// clientConn := core.NewClientConn(rtmpConn)
// if err = clientConn.ReadMsg(); err != nil {
// fmt.Printf("handleConn read msg error, %v\n", err)
// return
// }
// appname, _, _ := clientConn.GetStreamInfo()
// if ret := configure.CheckAppName(appname); !ret {
// err = fmt.Errorf("application name=%s is not configured", appname)
// s.logger.Errorf("CheckAppName failed, name:%s %v", appname, err)
// return
// }
// fmt.Printf("handleConn: IsPublisher=%v\n", clientConn.IsPublisher())
// if clientConn.IsPublisher() {
// if pushlist, ret := configure.GetStaticPushUrlList(appname); ret && (pushlist != nil) {
// s.logger.Infof("GetStaticPushUrlList: %v", pushlist)
// }
// reader := protocol.NewStreamReader(clientConn, s.logger)
// s.handler.HandleReader(reader)
// fmt.Printf("new publisher: %+v\n", reader.StreamInfo())
// if s.extendWriter != nil {
// writeType := reflect.TypeOf(s.extendWriter)
// fmt.Printf("handleConn:writeType=%v\n", writeType)
// writer, err := s.extendWriter.NewWriter(reader.StreamInfo())
// if err != nil {
// fmt.Printf("s.extendWriter.NewWriter failed, %v\n", err)
// return
// }
// s.handler.HandleWriter(writer)
// }
// flvDvr := &flv.FlvDvr{}
// flvWriter, err := flvDvr.NewWriter(reader.StreamInfo())
// if err != nil {
// fmt.Printf("Create flv writer failed, %v\n", err)
// } else {
// s.handler.HandleWriter(flvWriter)
// }
// } else {
// writer := protocol.NewStreamWriter(clientConn, s.logger)
// fmt.Printf("new player: %+v\n", writer.StreamInfo())
// s.handler.HandleWriter(writer)
// }
// return
// }
+26
View File
@@ -0,0 +1,26 @@
package gortmppush
import "github.com/H0RlZ0N/gortmppush/logger"
// SettingFunc ...
type SettingFunc func(*SettingEngine)
// SettingEngine ...
type SettingEngine struct {
loggerFactory logger.LoggerFactory
logLevel logger.LogLevel
}
// WithLoggerFactory 设置日志创建类
func WithLoggerFactory(v logger.LoggerFactory) SettingFunc {
return func(setting *SettingEngine) {
setting.loggerFactory = v
}
}
// WithLogLevel 设置日志等级
func WithLogLevel(v logger.LogLevel) SettingFunc {
return func(setting *SettingEngine) {
setting.logLevel = v
}
}
+3
View File
@@ -0,0 +1,3 @@
package utils
var RecommendBufioSize = 1024 * 64
+24
View File
@@ -0,0 +1,24 @@
package utils
type Pool struct {
pos int
buf []byte
}
const maxpoolsize = 50 * 1024
func (pool *Pool) Get(size int) []byte {
if maxpoolsize-pool.pos < size {
pool.pos = 0
pool.buf = make([]byte, maxpoolsize)
}
b := pool.buf[pool.pos : pool.pos+size]
pool.pos += size
return b
}
func NewPool() *Pool {
return &Pool{
buf: make([]byte, maxpoolsize),
}
}
+72
View File
@@ -0,0 +1,72 @@
package utils
import (
"sync"
"github.com/H0RlZ0N/gortmppush/av"
)
// Queue is a basic FIFO queue for Messages.
type Queue struct {
maxSize int
list []*av.Packet
mutex sync.Mutex
}
// NewQueue returns a new Queue. If maxSize is greater than zero the queue will
// not grow more than the defined size.
func NewQueue(maxSize int) *Queue {
return &Queue{
maxSize: maxSize,
}
}
// Push adds a message to the queue.
func (q *Queue) Push(msg *av.Packet) {
q.mutex.Lock()
defer q.mutex.Unlock()
if len(q.list) == q.maxSize {
q.pop()
}
q.list = append(q.list, msg)
}
// Pop removes and returns a message from the queue in first to last order.
func (q *Queue) Pop() *av.Packet {
q.mutex.Lock()
defer q.mutex.Unlock()
if len(q.list) == 0 {
return nil
}
return q.pop()
}
func (q *Queue) pop() *av.Packet {
x := len(q.list) - 1
msg := q.list[x]
q.list = q.list[:x]
return msg
}
// Len returns the length of the queue.
func (q *Queue) Len() int {
q.mutex.Lock()
defer q.mutex.Unlock()
return len(q.list)
}
// All returns and removes all messages from the queue.
func (q *Queue) All() []*av.Packet {
q.mutex.Lock()
defer q.mutex.Unlock()
cache := q.list
q.list = nil
return cache
}
+121
View File
@@ -0,0 +1,121 @@
package utils
func U8(b []byte) (i uint8) {
return b[0]
}
func U16BE(b []byte) (i uint16) {
i = uint16(b[0])
i <<= 8
i |= uint16(b[1])
return
}
func I16BE(b []byte) (i int16) {
i = int16(b[0])
i <<= 8
i |= int16(b[1])
return
}
func I24BE(b []byte) (i int32) {
i = int32(int8(b[0]))
i <<= 8
i |= int32(b[1])
i <<= 8
i |= int32(b[2])
return
}
func U24BE(b []byte) (i uint32) {
i = uint32(b[0])
i <<= 8
i |= uint32(b[1])
i <<= 8
i |= uint32(b[2])
return
}
func I32BE(b []byte) (i int32) {
i = int32(int8(b[0]))
i <<= 8
i |= int32(b[1])
i <<= 8
i |= int32(b[2])
i <<= 8
i |= int32(b[3])
return
}
func U32LE(b []byte) (i uint32) {
i = uint32(b[3])
i <<= 8
i |= uint32(b[2])
i <<= 8
i |= uint32(b[1])
i <<= 8
i |= uint32(b[0])
return
}
func U32BE(b []byte) (i uint32) {
i = uint32(b[0])
i <<= 8
i |= uint32(b[1])
i <<= 8
i |= uint32(b[2])
i <<= 8
i |= uint32(b[3])
return
}
func U40BE(b []byte) (i uint64) {
i = uint64(b[0])
i <<= 8
i |= uint64(b[1])
i <<= 8
i |= uint64(b[2])
i <<= 8
i |= uint64(b[3])
i <<= 8
i |= uint64(b[4])
return
}
func U64BE(b []byte) (i uint64) {
i = uint64(b[0])
i <<= 8
i |= uint64(b[1])
i <<= 8
i |= uint64(b[2])
i <<= 8
i |= uint64(b[3])
i <<= 8
i |= uint64(b[4])
i <<= 8
i |= uint64(b[5])
i <<= 8
i |= uint64(b[6])
i <<= 8
i |= uint64(b[7])
return
}
func I64BE(b []byte) (i int64) {
i = int64(int8(b[0]))
i <<= 8
i |= int64(b[1])
i <<= 8
i |= int64(b[2])
i <<= 8
i |= int64(b[3])
i <<= 8
i |= int64(b[4])
i <<= 8
i |= int64(b[5])
i <<= 8
i |= int64(b[6])
i <<= 8
i |= int64(b[7])
return
}
+12
View File
@@ -0,0 +1,12 @@
package utils
import (
uuid "github.com/satori/go.uuid"
)
func NewId() string {
id := uuid.NewV1()
//b64 := base64.URLEncoding.EncodeToString(id.Bytes()[:12])
//return b64
return id.String()
}
+87
View File
@@ -0,0 +1,87 @@
package utils
func PutU8(b []byte, v uint8) {
b[0] = v
}
func PutI16BE(b []byte, v int16) {
b[0] = byte(v >> 8)
b[1] = byte(v)
}
func PutU16BE(b []byte, v uint16) {
b[0] = byte(v >> 8)
b[1] = byte(v)
}
func PutI24BE(b []byte, v int32) {
b[0] = byte(v >> 16)
b[1] = byte(v >> 8)
b[2] = byte(v)
}
func PutU24BE(b []byte, v uint32) {
b[0] = byte(v >> 16)
b[1] = byte(v >> 8)
b[2] = byte(v)
}
func PutI32BE(b []byte, v int32) {
b[0] = byte(v >> 24)
b[1] = byte(v >> 16)
b[2] = byte(v >> 8)
b[3] = byte(v)
}
func PutU32BE(b []byte, v uint32) {
b[0] = byte(v >> 24)
b[1] = byte(v >> 16)
b[2] = byte(v >> 8)
b[3] = byte(v)
}
func PutU32LE(b []byte, v uint32) {
b[3] = byte(v >> 24)
b[2] = byte(v >> 16)
b[1] = byte(v >> 8)
b[0] = byte(v)
}
func PutU40BE(b []byte, v uint64) {
b[0] = byte(v >> 32)
b[1] = byte(v >> 24)
b[2] = byte(v >> 16)
b[3] = byte(v >> 8)
b[4] = byte(v)
}
func PutU48BE(b []byte, v uint64) {
b[0] = byte(v >> 40)
b[1] = byte(v >> 32)
b[2] = byte(v >> 24)
b[3] = byte(v >> 16)
b[4] = byte(v >> 8)
b[5] = byte(v)
}
func PutU64BE(b []byte, v uint64) {
b[0] = byte(v >> 56)
b[1] = byte(v >> 48)
b[2] = byte(v >> 40)
b[3] = byte(v >> 32)
b[4] = byte(v >> 24)
b[5] = byte(v >> 16)
b[6] = byte(v >> 8)
b[7] = byte(v)
}
func PutI64BE(b []byte, v int64) {
b[0] = byte(v >> 56)
b[1] = byte(v >> 48)
b[2] = byte(v >> 40)
b[3] = byte(v >> 32)
b[4] = byte(v >> 24)
b[5] = byte(v >> 16)
b[6] = byte(v >> 8)
b[7] = byte(v)
}