Files
monibuca/plugin/rtp/pkg/forward.go
T
langhuihui 8a9fffb987 refactor: frame converter and mp4 track improvements
- Refactor frame converter implementation
- Update mp4 track to use ICodex
- General refactoring and code improvements

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-08-28 19:55:37 +08:00

477 lines
11 KiB
Go

package rtp
import (
"context"
"encoding/binary"
"fmt"
"io"
"net"
"time"
"github.com/pion/rtp"
"m7s.live/v5/pkg/util"
)
// ConnectionConfig describes one endpoint (source or target) of a forward.
type ConnectionConfig struct {
	IP   string     // address to dial, or local address to bind/listen on
	Port uint32     // port paired with IP
	Mode StreamMode // transport: TCP active (dial), TCP passive (listen) or UDP
	SSRC uint32     // RTP SSRC; 0 disables SSRC filtering (source) / rewriting (target)
}
// ForwardConfig configures a Forwarder.
type ForwardConfig struct {
	Source ConnectionConfig // endpoint packets are read from
	Target ConnectionConfig // endpoint packets are written to
	// Relay selects byte-level relaying (RelayProcessor) instead of parsing
	// and re-emitting individual RTP packets (RTPProcessor).
	Relay bool
}
// Forwarder moves RTP traffic from a source connection to a target connection.
type Forwarder struct {
	config *ForwardConfig
	source net.Conn // opened by setupConnections, closed by cleanup
	target net.Conn // opened by setupConnections, closed by cleanup
}
// NewForwarder returns a Forwarder bound to config. No connection is opened
// until Forward is called.
func NewForwarder(config *ForwardConfig) *Forwarder {
	f := new(Forwarder)
	f.config = config
	return f
}
// establishSourceConnection opens the connection packets are read from.
//
// Depending on config.Mode it:
//   - StreamModeTCPActive: dials config.IP:config.Port with a 10s timeout;
//   - StreamModeTCPPassive: listens on config.IP:config.Port and accepts a
//     single peer, giving up after 30s;
//   - StreamModeUDP: binds a UDP4 socket on config.IP:config.Port.
//
// Any other mode yields an "unsupported mode" error.
func (f *Forwarder) establishSourceConnection(config ConnectionConfig) (net.Conn, error) {
	// JoinHostPort brackets IPv6 literals, which a plain "%s:%d" would not.
	addr := net.JoinHostPort(config.IP, fmt.Sprint(config.Port))
	switch config.Mode {
	case StreamModeTCPActive:
		dialer := &net.Dialer{Timeout: 10 * time.Second}
		conn, err := dialer.Dial("tcp", addr)
		if err != nil {
			return nil, fmt.Errorf("connect failed: %w", err)
		}
		return conn, nil
	case StreamModeTCPPassive:
		listener, err := net.Listen("tcp4", addr)
		if err != nil {
			return nil, fmt.Errorf("listen failed: %w", err)
		}
		// Only one peer is ever accepted. Closing the listener on return
		// releases the port and does not affect the accepted connection.
		defer listener.Close()
		if tcpListener, ok := listener.(*net.TCPListener); ok {
			if err := tcpListener.SetDeadline(time.Now().Add(30 * time.Second)); err != nil {
				return nil, fmt.Errorf("set accept deadline failed: %w", err)
			}
		}
		conn, err := listener.Accept()
		if err != nil {
			return nil, fmt.Errorf("accept failed: %w", err)
		}
		return conn, nil
	case StreamModeUDP:
		// Source UDP: listen for incoming datagrams.
		udpAddr, err := net.ResolveUDPAddr("udp4", addr)
		if err != nil {
			return nil, fmt.Errorf("resolve UDP address failed: %w", err)
		}
		conn, err := net.ListenUDP("udp4", udpAddr)
		if err != nil {
			return nil, fmt.Errorf("UDP listen failed: %w", err)
		}
		return conn, nil
	}
	return nil, fmt.Errorf("unsupported mode: %s", config.Mode)
}
// establishTargetConnection opens the connection packets are written to.
//
// Depending on config.Mode it:
//   - StreamModeTCPActive: dials config.IP:config.Port with a 10s timeout;
//   - StreamModeTCPPassive: listens on config.IP:config.Port and accepts a
//     single peer, giving up after 30s;
//   - StreamModeUDP: creates a UDP socket connected to config.IP:config.Port.
//
// Any other mode yields an "unsupported mode" error.
func (f *Forwarder) establishTargetConnection(config ConnectionConfig) (net.Conn, error) {
	// JoinHostPort brackets IPv6 literals, which a plain "%s:%d" would not.
	addr := net.JoinHostPort(config.IP, fmt.Sprint(config.Port))
	switch config.Mode {
	case StreamModeTCPActive:
		dialer := &net.Dialer{Timeout: 10 * time.Second}
		conn, err := dialer.Dial("tcp", addr)
		if err != nil {
			return nil, fmt.Errorf("connect failed: %w", err)
		}
		return conn, nil
	case StreamModeTCPPassive:
		listener, err := net.Listen("tcp4", addr)
		if err != nil {
			return nil, fmt.Errorf("listen failed: %w", err)
		}
		// Only one peer is ever accepted. Closing the listener on return
		// releases the port and does not affect the accepted connection.
		defer listener.Close()
		if tcpListener, ok := listener.(*net.TCPListener); ok {
			if err := tcpListener.SetDeadline(time.Now().Add(30 * time.Second)); err != nil {
				return nil, fmt.Errorf("set accept deadline failed: %w", err)
			}
		}
		conn, err := listener.Accept()
		if err != nil {
			return nil, fmt.Errorf("accept failed: %w", err)
		}
		return conn, nil
	case StreamModeUDP:
		// Resolve rather than use net.ParseIP. ParseIP returns nil for host
		// names and typos, and DialUDP treats a nil IP as the local system,
		// so the stream would silently go to the wrong host.
		udpAddr, err := net.ResolveUDPAddr("udp", addr)
		if err != nil {
			return nil, fmt.Errorf("resolve UDP address failed: %w", err)
		}
		conn, err := net.DialUDP("udp", nil, udpAddr)
		if err != nil {
			return nil, fmt.Errorf("UDP dial failed: %w", err)
		}
		return conn, nil
	}
	return nil, fmt.Errorf("unsupported mode: %s", config.Mode)
}
// setupConnections opens the source connection, then the target connection,
// and stores them in f.source and f.target.
//
// If the target cannot be opened, the source connection is closed again and
// f.source is reset to nil. Forward only defers cleanup after a successful
// setup, so nothing else would release it.
func (f *Forwarder) setupConnections() error {
	var err error
	if f.source, err = f.establishSourceConnection(f.config.Source); err != nil {
		return fmt.Errorf("source connection failed: %w", err)
	}
	if f.target, err = f.establishTargetConnection(f.config.Target); err != nil {
		// Best-effort close; the target error is the one worth reporting.
		f.source.Close()
		f.source = nil
		return fmt.Errorf("target connection failed: %w", err)
	}
	return nil
}
// cleanup closes whichever of the source and target connections are set,
// source first. Close errors are deliberately ignored.
func (f *Forwarder) cleanup() {
	closeConn := func(c interface{ Close() error }) {
		if c != nil {
			c.Close()
		}
	}
	closeConn(f.source)
	closeConn(f.target)
}
// createRTPReader wraps f.source in the packet reader that matches the source
// transport: a UDP reader for StreamModeUDP and a TCP reader for either TCP
// mode. It returns nil for any other mode.
func (f *Forwarder) createRTPReader() IRTPReader {
	mode := f.config.Source.Mode
	if mode == StreamModeUDP {
		return NewRTPUDPReader(f.source)
	}
	if mode == StreamModeTCPActive || mode == StreamModeTCPPassive {
		return NewRTPTCPReader(f.source)
	}
	return nil
}
// createRTPWriter wraps f.target in an RTPWriter that frames packets
// according to the target's stream mode.
func (f *Forwarder) createRTPWriter() RTPWriter {
	cfg := f.config.Target
	return NewRTPWriter(f.target, cfg.Mode)
}
// RTPWriter sends RTP data to a target using that target's framing.
type RTPWriter interface {
	// WriteRaw sends an already-marshaled RTP packet.
	WriteRaw(data []byte) error
	// WritePacket marshals packet and sends the result.
	WritePacket(packet *rtp.Packet) error
}
// rtpWriter is the RTPWriter implementation returned by NewRTPWriter.
type rtpWriter struct {
	writer     io.Writer
	mode       StreamMode
	header     []byte      // reusable 2-byte length prefix for non-UDP framing
	sendBuffer util.Buffer // reusable marshal buffer for WritePacket
}
// NewRTPWriter returns an RTPWriter that writes to writer. UDP mode sends each
// packet as-is; other modes prefix it with a 2-byte length.
func NewRTPWriter(writer io.Writer, mode StreamMode) RTPWriter {
	w := &rtpWriter{writer: writer, mode: mode}
	w.header = make([]byte, 2)
	// Scratch buffer reused by every WritePacket call.
	w.sendBuffer = util.Buffer{}
	return w
}
// WritePacket marshals packet into the writer's reusable send buffer and
// forwards the bytes through WriteRaw.
func (w *rtpWriter) WritePacket(packet *rtp.Packet) error {
	buf := &w.sendBuffer
	buf.Reset()
	buf.Malloc(packet.MarshalSize())
	if _, err := packet.MarshalTo(*buf); err != nil {
		return fmt.Errorf("marshal RTP packet failed: %v", err)
	}
	return w.WriteRaw(*buf)
}
// WriteRaw sends one marshaled RTP packet.
//
// In StreamModeUDP the bytes are written unchanged (one Write per packet).
// In every other mode they are preceded by a 2-byte big-endian length (RFC
// 4571-style TCP framing). A packet longer than 0xFFFF bytes cannot be
// described by that prefix and is rejected without writing anything.
func (w *rtpWriter) WriteRaw(data []byte) error {
	if w.mode == StreamModeUDP {
		_, err := w.writer.Write(data)
		return err
	}
	// Truncating the length to 16 bits would desync the receiver's framing
	// for the rest of the stream, so fail loudly instead.
	if len(data) > 0xFFFF {
		return fmt.Errorf("RTP packet too large for TCP framing: %d bytes", len(data))
	}
	binary.BigEndian.PutUint16(w.header, uint16(len(data)))
	if _, err := w.writer.Write(w.header); err != nil {
		return err
	}
	_, err := w.writer.Write(data)
	return err
}
// RelayProcessor copies raw bytes from reader to writer. When the source is
// UDP and the target TCP (or vice versa), it converts between datagrams and
// 2-byte length-prefixed framing.
type RelayProcessor struct {
	reader     io.Reader
	writer     io.Writer
	sourceMode StreamMode
	targetMode StreamMode
	buffer     []byte // reusable packet buffer
	header     []byte // reusable 2-byte length prefix
}
// NewRelayProcessor returns a RelayProcessor that relays from reader to
// writer, given the source and target stream modes.
func NewRelayProcessor(reader io.Reader, writer io.Writer, sourceMode, targetMode StreamMode) *RelayProcessor {
	p := &RelayProcessor{reader: reader, writer: writer}
	p.sourceMode, p.targetMode = sourceMode, targetMode
	// Scratch space reused across packets: a payload buffer and the 2-byte
	// length prefix used for TCP framing.
	p.buffer = make([]byte, 1460)
	p.header = make([]byte, 2)
	return p
}
// Process relays data from p.reader to p.writer until EOF, an error, or
// (for the framing-conversion paths) cancellation of ctx.
//
// If both sides use the same framing (identical modes, or TCP active and TCP
// passive, which differ only in who dials), bytes are copied verbatim.
// UDP→TCP adds 2-byte length prefixes and TCP→UDP strips them. Any other
// combination returns an "unsupported mode combination" error.
func (p *RelayProcessor) Process(ctx context.Context) error {
	isTCP := func(m StreamMode) bool {
		return m == StreamModeTCPActive || m == StreamModeTCPPassive
	}
	switch {
	case p.sourceMode == p.targetMode, isTCP(p.sourceMode) && isTCP(p.targetMode):
		_, err := io.Copy(p.writer, p.reader)
		return err
	case p.sourceMode == StreamModeUDP && isTCP(p.targetMode):
		return p.processUDPToTCP(ctx)
	case isTCP(p.sourceMode) && p.targetMode == StreamModeUDP:
		return p.processTCPToUDP(ctx)
	}
	return fmt.Errorf("unsupported mode combination")
}
// processUDPToTCP treats each Read from p.reader as one packet (one datagram
// for a UDP source). It writes each packet to p.writer prefixed with its
// length as a 2-byte big-endian integer.
//
// It returns nil when the reader reports io.EOF, ctx.Err() if ctx is done
// before a read, and any other read or write error unchanged.
func (p *RelayProcessor) processUDPToTCP(ctx context.Context) error {
	// A datagram read into a smaller buffer is truncated (or fails) and the
	// remainder cannot be recovered. Make room for the largest payload the
	// 2-byte length prefix can describe before reading anything.
	const maxFrame = 0xFFFF
	if len(p.buffer) < maxFrame {
		p.buffer = make([]byte, maxFrame)
	}
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}
		n, err := p.reader.Read(p.buffer)
		if err != nil {
			if err == io.EOF {
				return nil
			}
			return err
		}
		// n <= maxFrame, so the length always fits in 16 bits.
		binary.BigEndian.PutUint16(p.header, uint16(n))
		if _, err := p.writer.Write(p.header); err != nil {
			return err
		}
		if _, err := p.writer.Write(p.buffer[:n]); err != nil {
			return err
		}
	}
}
// processTCPToUDP reads 2-byte big-endian length-prefixed frames from
// p.reader and writes each frame's payload to p.writer in a single Write.
//
// It returns nil on io.EOF at a frame boundary, ctx.Err() if ctx is done
// between frames, and any other error unchanged. A frame cut off mid-way
// yields io.ErrUnexpectedEOF.
func (p *RelayProcessor) processTCPToUDP(ctx context.Context) error {
	// Largest payload a 16-bit length prefix can announce.
	const maxFrame = 0xFFFF
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}
		if _, err := io.ReadFull(p.reader, p.header); err != nil {
			if err == io.EOF {
				return nil
			}
			return err
		}
		size := int(binary.BigEndian.Uint16(p.header))
		if size > len(p.buffer) {
			// Grow the reusable buffer once to the maximum frame size
			// instead of allocating a fresh slice for every large packet.
			p.buffer = make([]byte, maxFrame)
		}
		packet := p.buffer[:size]
		if _, err := io.ReadFull(p.reader, packet); err != nil {
			return err
		}
		if _, err := p.writer.Write(packet); err != nil {
			return err
		}
	}
}
// RTPProcessor reads RTP packets from reader and applies the SSRC filtering
// and rewriting configured in config. It writes the packets through writer,
// fragmenting oversized ones.
type RTPProcessor struct {
	reader     IRTPReader
	writer     RTPWriter
	config     *ForwardConfig
	sendBuffer util.Buffer // NOTE(review): not referenced by any RTPProcessor method in this file
}
// NewRTPProcessor returns an RTPProcessor that reads packets from reader and
// forwards them through writer according to config.
func NewRTPProcessor(reader IRTPReader, writer RTPWriter, config *ForwardConfig) *RTPProcessor {
	p := new(RTPProcessor)
	p.reader, p.writer, p.config = reader, writer, config
	p.sendBuffer = util.Buffer{}
	return p
}
// Process reads RTP packets from p.reader and forwards them through p.writer
// until the reader reports io.EOF (returns nil), ctx is done between packets
// (returns ctx.Err()), or an error occurs.
//
// When config.Source.SSRC is non-zero, packets with any other SSRC are
// dropped. Packets that would marshal to more than 1460 bytes go to
// processFragmentedPacket; all others go to processSinglePacket.
func (p *RTPProcessor) Process(ctx context.Context) error {
	// Upper bound on the marshaled size of a forwarded packet.
	const maxPacketSize = 1460
	var packet rtp.Packet
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}
		if err := p.reader.Read(&packet); err != nil {
			if err == io.EOF {
				return nil
			}
			return fmt.Errorf("read RTP packet failed: %v", err)
		}
		if p.config.Source.SSRC != 0 && packet.SSRC != p.config.Source.SSRC {
			continue
		}
		// MarshalSize already includes the payload, so compare it directly
		// with the limit. Subtracting it from the limit and comparing that to
		// the payload would count the payload twice.
		var err error
		if packet.MarshalSize() > maxPacketSize {
			// NOTE(review): fragments take consecutive sequence numbers
			// starting at this packet's, which overlap the sequence numbers
			// of the packets that follow. Receivers may treat them as
			// duplicates. Confirm whether downstream renumbers.
			err = p.processFragmentedPacket(&packet, packet.SequenceNumber)
		} else {
			err = p.processSinglePacket(&packet)
		}
		if err != nil {
			return err
		}
	}
}
// processSinglePacket writes packet as-is. When config.Target.SSRC is
// non-zero, it first overwrites the packet's SSRC in place.
func (p *RTPProcessor) processSinglePacket(packet *rtp.Packet) error {
	if ssrc := p.config.Target.SSRC; ssrc != 0 {
		packet.SSRC = ssrc
	}
	return p.writer.WritePacket(packet)
}
// processFragmentedPacket splits packet's payload across several RTP packets,
// so that each fragment's header plus payload fits in 1460 bytes, and writes
// them in order.
//
// Each fragment copies packet's header and takes consecutive sequence numbers
// starting at sequenceNumber. It gets config.Target.SSRC when that is
// non-zero. Only the last fragment carries packet's original marker bit.
// packet itself is not modified.
func (p *RTPProcessor) processFragmentedPacket(packet *rtp.Packet, sequenceNumber uint16) error {
	const maxPacketSize = 1460
	// Size chunks from the real header length: CSRCs and header extensions
	// can make it longer than the fixed 12 bytes.
	headerSize := packet.Header.MarshalSize()
	maxPayloadSize := maxPacketSize - headerSize
	if maxPayloadSize <= 0 {
		// A non-positive chunk size would make the loop below never advance.
		return fmt.Errorf("RTP header too large to fragment: %d bytes", headerSize)
	}
	payload := packet.Payload
	for start := 0; start < len(payload); start += maxPayloadSize {
		end := start + maxPayloadSize
		if end > len(payload) {
			end = len(payload)
		}
		fragment := *packet
		fragment.Payload = payload[start:end]
		fragment.SequenceNumber = sequenceNumber
		sequenceNumber++
		// The last fragment is the one that reaches the end of the payload.
		// This also holds when the payload is an exact multiple of the chunk
		// size.
		fragment.Marker = packet.Marker && end == len(payload)
		if p.config.Target.SSRC != 0 {
			fragment.SSRC = p.config.Target.SSRC
		}
		if err := p.writer.WritePacket(&fragment); err != nil {
			return fmt.Errorf("write RTP fragment failed: %v", err)
		}
	}
	return nil
}
// Forward opens the source and target connections and forwards traffic
// between them. It stops when the source ends, an error occurs, or ctx is
// done.
//
// With config.Relay set, bytes are relayed by a RelayProcessor. Otherwise
// RTP packets are parsed and re-written by an RTPProcessor. Both connections
// are closed before Forward returns. Cancelling ctx closes them immediately
// to unblock pending reads; Forward then returns ctx.Err().
func (f *Forwarder) Forward(ctx context.Context) error {
	if err := f.setupConnections(); err != nil {
		return err
	}
	defer f.cleanup()

	// The processors only check ctx between packets, and io.Copy never
	// does, so a Read blocked on a silent peer would ignore cancellation.
	// Closing the connections from this watcher forces that Read to return.
	// The watcher exits when Forward returns (done is closed first by defer
	// order).
	done := make(chan struct{})
	defer close(done)
	go func() {
		select {
		case <-ctx.Done():
			f.cleanup()
		case <-done:
		}
	}()

	var err error
	if f.config.Relay {
		processor := NewRelayProcessor(f.source, f.target, f.config.Source.Mode, f.config.Target.Mode)
		err = processor.Process(ctx)
	} else {
		processor := NewRTPProcessor(f.createRTPReader(), f.createRTPWriter(), f.config)
		err = processor.Process(ctx)
	}
	if err != nil && ctx.Err() != nil {
		// Once ctx is done, a failure is the cancellation (or the close it
		// triggered, e.g. "use of closed network connection"); report that.
		return ctx.Err()
	}
	return err
}