diff --git a/pkg/nack/responder_interceptor.go b/pkg/nack/responder_interceptor.go index 1d5745f..22d038b 100644 --- a/pkg/nack/responder_interceptor.go +++ b/pkg/nack/responder_interceptor.go @@ -4,7 +4,6 @@ package nack import ( - "encoding/binary" "sync" "github.com/pion/interceptor" @@ -19,7 +18,7 @@ type ResponderInterceptorFactory struct { } type packetFactory interface { - NewPacket(header *rtp.Header, payload []byte) (*retainablePacket, error) + NewPacket(header *rtp.Header, payload []byte, rtxSsrc uint32, rtxPayloadType uint8) (*retainablePacket, error) } // NewInterceptor constructs a new ResponderInterceptor @@ -63,11 +62,6 @@ type ResponderInterceptor struct { type localStream struct { sendBuffer *sendBuffer rtpWriter interceptor.RTPWriter - - // Non-zero if Retransmissions should be sent on a distinct stream - rtxSsrc uint32 - rtxPayloadType uint8 - rtxSequencer rtp.Sequencer } // NewResponderInterceptor returns a new ResponderInterceptorFactor @@ -115,16 +109,13 @@ func (n *ResponderInterceptor) BindLocalStream(info *interceptor.StreamInfo, wri sendBuffer, _ := newSendBuffer(n.size) n.streamsMu.Lock() n.streams[info.SSRC] = &localStream{ - sendBuffer: sendBuffer, - rtpWriter: writer, - rtxSsrc: info.SSRCRetransmission, - rtxPayloadType: info.PayloadTypeRetransmission, - rtxSequencer: rtp.NewRandomSequencer(), + sendBuffer: sendBuffer, + rtpWriter: writer, } n.streamsMu.Unlock() return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { - pkt, err := n.packetFactory.NewPacket(header, payload) + pkt, err := n.packetFactory.NewPacket(header, payload, info.SSRCRetransmission, info.PayloadTypeRetransmission) if err != nil { return 0, err } @@ -151,43 +142,8 @@ func (n *ResponderInterceptor) resendPackets(nack *rtcp.TransportLayerNack) { for i := range nack.Nacks { nack.Nacks[i].Range(func(seq uint16) bool { if p := stream.sendBuffer.get(seq); p != nil { - if stream.rtxSsrc != 0 { - // Store the original sequence number and rewrite the sequence number. - originalSequenceNumber := p.Header().SequenceNumber - p.Header().SequenceNumber = stream.rtxSequencer.NextSequenceNumber() - - // Rewrite the SSRC. - p.Header().SSRC = stream.rtxSsrc - // Rewrite the payload type. - p.Header().PayloadType = stream.rtxPayloadType - - // Remove padding if present. - paddingLength := 0 - originPayload := p.Payload() - if p.Header().Padding { - paddingLength = int(originPayload[len(originPayload)-1]) - p.Header().Padding = false - } - - // Write the original sequence number at the beginning of the payload. - payload := make([]byte, 2) - binary.BigEndian.PutUint16(payload, originalSequenceNumber) - payload = append(payload, originPayload[:len(originPayload)-paddingLength]...) - - // Send RTX packet. - if _, err := stream.rtpWriter.Write(p.Header(), payload, interceptor.Attributes{}); err != nil { - n.log.Warnf("failed sending rtx packet: %+v", err) - } - - // Resore the Padding and SSRC. - if paddingLength > 0 { - p.Header().Padding = true - } - p.Header().SequenceNumber = originalSequenceNumber - } else { - if _, err := stream.rtpWriter.Write(p.Header(), p.Payload(), interceptor.Attributes{}); err != nil { - n.log.Warnf("failed resending nacked packet: %+v", err) - } + if _, err := stream.rtpWriter.Write(p.Header(), p.Payload(), interceptor.Attributes{}); err != nil { + n.log.Warnf("failed resending nacked packet: %+v", err) } p.Release() } diff --git a/pkg/nack/retainable_packet.go b/pkg/nack/retainable_packet.go index 31e9d83..ef05ed4 100644 --- a/pkg/nack/retainable_packet.go +++ b/pkg/nack/retainable_packet.go @@ -4,6 +4,7 @@ package nack import ( + "encoding/binary" "io" "sync" @@ -13,8 +14,9 @@ import ( const maxPayloadLen = 1460 type packetManager struct { - headerPool *sync.Pool - payloadPool *sync.Pool + headerPool *sync.Pool + payloadPool *sync.Pool + rtxSequencer rtp.Sequencer } func newPacketManager() *packetManager { @@ -30,16 +32,18 @@ func newPacketManager() *packetManager { return &buf }, }, + rtxSequencer: rtp.NewRandomSequencer(), } } -func (m *packetManager) NewPacket(header *rtp.Header, payload []byte) (*retainablePacket, error) { +func (m *packetManager) NewPacket(header *rtp.Header, payload []byte, rtxSsrc uint32, rtxPayloadType uint8) (*retainablePacket, error) { if len(payload) > maxPayloadLen { return nil, io.ErrShortBuffer } p := &retainablePacket{ - onRelease: m.releasePacket, + onRelease: m.releasePacket, + sequenceNumber: header.SequenceNumber, // new packets have retain count of 1 count: 1, } @@ -62,6 +66,29 @@ func (m *packetManager) NewPacket(header *rtp.Header, payload []byte) (*retainab p.payload = (*p.buffer)[:size] } + if rtxSsrc != 0 && rtxPayloadType != 0 { + // Store the original sequence number and rewrite the sequence number. + originalSequenceNumber := p.header.SequenceNumber + p.header.SequenceNumber = m.rtxSequencer.NextSequenceNumber() + + // Rewrite the SSRC. + p.header.SSRC = rtxSsrc + // Rewrite the payload type. + p.header.PayloadType = rtxPayloadType + + // Remove padding if present. + paddingLength := 0 + if p.header.Padding { + paddingLength = int(p.payload[len(p.payload)-1]) + p.header.Padding = false + } + + // Write the original sequence number at the beginning of the payload. + payload := make([]byte, 2) + binary.BigEndian.PutUint16(payload, originalSequenceNumber) + p.payload = append(payload, p.payload[:len(p.payload)-paddingLength]...) + } + return p, nil } @@ -74,12 +101,13 @@ func (m *packetManager) releasePacket(header *rtp.Header, payload *[]byte) { type noOpPacketFactory struct{} -func (f *noOpPacketFactory) NewPacket(header *rtp.Header, payload []byte) (*retainablePacket, error) { +func (f *noOpPacketFactory) NewPacket(header *rtp.Header, payload []byte, _ uint32, _ uint8) (*retainablePacket, error) { return &retainablePacket{ - onRelease: f.releasePacket, - count: 1, - header: header, - payload: payload, + onRelease: f.releasePacket, + count: 1, + header: header, + payload: payload, + sequenceNumber: header.SequenceNumber, }, nil } @@ -96,6 +124,8 @@ type retainablePacket struct { header *rtp.Header buffer *[]byte payload []byte + + sequenceNumber uint16 } func (p *retainablePacket) Header() *rtp.Header { diff --git a/pkg/nack/send_buffer.go b/pkg/nack/send_buffer.go index e8e816f..2b3b076 100644 --- a/pkg/nack/send_buffer.go +++ b/pkg/nack/send_buffer.go @@ -46,7 +46,7 @@ func (s *sendBuffer) add(packet *retainablePacket) { s.m.Lock() defer s.m.Unlock() - seq := packet.Header().SequenceNumber + seq := packet.sequenceNumber if !s.started { s.packets[seq%s.size] = packet s.lastAdded = seq @@ -92,7 +92,7 @@ func (s *sendBuffer) get(seq uint16) *retainablePacket { pkt := s.packets[seq%s.size] if pkt != nil { - if pkt.Header().SequenceNumber != seq { + if pkt.sequenceNumber != seq { return nil } // already released diff --git a/pkg/nack/send_buffer_test.go b/pkg/nack/send_buffer_test.go index e221760..8e45f0f 100644 --- a/pkg/nack/send_buffer_test.go +++ b/pkg/nack/send_buffer_test.go @@ -21,7 +21,7 @@ func TestSendBuffer(t *testing.T) { add := func(nums ...uint16) { for _, n := range nums { seq := start + n - pkt, err := pm.NewPacket(&rtp.Header{SequenceNumber: seq}, nil) + pkt, err := pm.NewPacket(&rtp.Header{SequenceNumber: seq}, nil, 0, 0) require.NoError(t, err) sb.add(pkt) } @@ -78,7 +78,7 @@ func TestSendBuffer_Overridden(t *testing.T) { require.Equal(t, uint16(1), sb.size) originalBytes := []byte("originalContent") - pkt, err := pm.NewPacket(&rtp.Header{SequenceNumber: 1}, originalBytes) + pkt, err := pm.NewPacket(&rtp.Header{SequenceNumber: 1}, originalBytes, 0, 0) require.NoError(t, err) sb.add(pkt) @@ -91,7 +91,7 @@ func TestSendBuffer_Overridden(t *testing.T) { require.Equal(t, 1, retrieved.count) // ensure original packet is released - pkt, err = pm.NewPacket(&rtp.Header{SequenceNumber: 2}, originalBytes) + pkt, err = pm.NewPacket(&rtp.Header{SequenceNumber: 2}, originalBytes, 0, 0) require.NoError(t, err) sb.add(pkt) require.Equal(t, 0, retrieved.count) @@ -113,7 +113,7 @@ func TestSendBuffer_Race(t *testing.T) { add := func(nums ...uint16) { for _, n := range nums { seq := start + n - pkt, err := pm.NewPacket(&rtp.Header{SequenceNumber: seq}, nil) + pkt, err := pm.NewPacket(&rtp.Header{SequenceNumber: seq}, nil, 0, 0) require.NoError(t, err) sb.add(pkt) }