diff --git a/internal/rtpbuffer/packet_factory.go b/internal/rtpbuffer/packet_factory.go index 4ab07fb..d3f5a12 100644 --- a/internal/rtpbuffer/packet_factory.go +++ b/internal/rtpbuffer/packet_factory.go @@ -11,6 +11,8 @@ import ( "github.com/pion/rtp" ) +const rtxSsrcByteLength = 2 + // PacketFactory allows custom logic around the handle of RTP Packets before they added to the RTPBuffer. // The NoOpPacketFactory doesn't copy packets, while the RetainablePacket will take a copy before adding type PacketFactory interface { @@ -68,32 +70,38 @@ func (m *PacketFactoryCopy) NewPacket(header *rtp.Header, payload []byte, rtxSsr if !ok { return nil, errFailedToCastPayloadPool } - - size := copy(*p.buffer, payload) - p.payload = (*p.buffer)[:size] + if rtxSsrc != 0 && rtxPayloadType != 0 { + size := copy((*p.buffer)[rtxSsrcByteLength:], payload) + p.payload = (*p.buffer)[:size+rtxSsrcByteLength] + } else { + size := copy(*p.buffer, payload) + p.payload = (*p.buffer)[:size] + } } if rtxSsrc != 0 && rtxPayloadType != 0 { - // Store the original sequence number and rewrite the sequence number. - originalSequenceNumber := p.header.SequenceNumber - p.header.SequenceNumber = m.rtxSequencer.NextSequenceNumber() + if payload == nil { + p.buffer, ok = m.payloadPool.Get().(*[]byte) + if !ok { + return nil, errFailedToCastPayloadPool + } + p.payload = (*p.buffer)[:rtxSsrcByteLength] + } + // Write the original sequence number at the beginning of the payload. + binary.BigEndian.PutUint16(p.payload, p.header.SequenceNumber) // Rewrite the SSRC. p.header.SSRC = rtxSsrc // Rewrite the payload type. p.header.PayloadType = rtxPayloadType - + // Rewrite the sequence number. + p.header.SequenceNumber = m.rtxSequencer.NextSequenceNumber() // Remove padding if present. - paddingLength := 0 if p.header.Padding && p.payload != nil && len(p.payload) > 0 { - paddingLength = int(p.payload[len(p.payload)-1]) + paddingLength := int(p.payload[len(p.payload)-1]) p.header.Padding = false + p.payload = (*p.buffer)[:len(p.payload)-paddingLength] } - - // Write the original sequence number at the beginning of the payload. - payload := make([]byte, 2) - binary.BigEndian.PutUint16(payload, originalSequenceNumber) - p.payload = append(payload, p.payload[:len(p.payload)-paddingLength]...) } return p, nil diff --git a/internal/rtpbuffer/rtpbuffer_test.go b/internal/rtpbuffer/rtpbuffer_test.go index 746fb5c..f1cac81 100644 --- a/internal/rtpbuffer/rtpbuffer_test.go +++ b/internal/rtpbuffer/rtpbuffer_test.go @@ -70,6 +70,66 @@ func TestRTPBuffer(t *testing.T) { } } +func TestRTPBuffer_WithRTX(t *testing.T) { + pm := NewPacketFactoryCopy() + for _, start := range []uint16{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 511, 512, 513, 32767, 32768, 32769, 65527, 65528, 65529, 65530, 65531, 65532, 65533, 65534, 65535} { + start := start + + sb, err := NewRTPBuffer(8) + require.NoError(t, err) + + add := func(nums ...uint16) { + for _, n := range nums { + seq := start + n + pkt, err := pm.NewPacket(&rtp.Header{SequenceNumber: seq, PayloadType: 2}, []byte("originalcontent"), 1, 1) + require.NoError(t, err) + sb.Add(pkt) + } + } + + assertGet := func(nums ...uint16) { + t.Helper() + for _, n := range nums { + seq := start + n + packet := sb.Get(seq) + if packet == nil { + t.Errorf("packet not found: %d", seq) + continue + } + if packet.Header().SSRC != 1 && packet.Header().PayloadType != 1 { + t.Errorf("packet for %d returned with incorrect SSRC : %d and PayloadType: %d", seq, packet.Header().SSRC, packet.Header().PayloadType) + } + packet.Release() + } + } + assertNOTGet := func(nums ...uint16) { + t.Helper() + for _, n := range nums { + seq := start + n + packet := sb.Get(seq) + if packet != nil { + t.Errorf("packet found for %d: %d", seq, packet.Header().SequenceNumber) + } + } + } + + add(0, 1, 2, 3, 4, 5, 6, 7) + assertGet(0, 1, 2, 3, 4, 5, 6, 7) + + add(8) + assertGet(8) + assertNOTGet(0) + + add(10) + assertGet(10) + assertNOTGet(1, 2, 9) + + add(22) + assertGet(22) + assertNOTGet(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21) + } +} + func TestRTPBuffer_Overridden(t *testing.T) { // override original packet content and get pm := NewPacketFactoryCopy() @@ -98,3 +158,60 @@ func TestRTPBuffer_Overridden(t *testing.T) { require.Nil(t, sb.Get(1)) } + +func TestRTPBuffer_Overridden_WithRTX_AND_Padding(t *testing.T) { + // override original packet content and get + pm := NewPacketFactoryCopy() + sb, err := NewRTPBuffer(1) + require.NoError(t, err) + require.Equal(t, uint16(1), sb.size) + + originalBytes := []byte("originalContent\x01") + pkt, err := pm.NewPacket(&rtp.Header{SequenceNumber: 1, Padding: true, SSRC: 2, PayloadType: 3}, originalBytes, 1, 1) + require.NoError(t, err) + sb.Add(pkt) + + // change payload + copy(originalBytes, "altered") + retrieved := sb.Get(1) + require.NotNil(t, retrieved) + require.Equal(t, "\x00\x01originalContent", string(retrieved.Payload())) + retrieved.Release() + require.Equal(t, 1, retrieved.count) + + // ensure original packet is released + pkt, err = pm.NewPacket(&rtp.Header{SequenceNumber: 2}, originalBytes, 1, 1) + require.NoError(t, err) + sb.Add(pkt) + require.Equal(t, 0, retrieved.count) + + require.Nil(t, sb.Get(1)) +} + +func TestRTPBuffer_Overridden_WithRTX_NILPayload(t *testing.T) { + // override original packet content and get + pm := NewPacketFactoryCopy() + sb, err := NewRTPBuffer(1) + require.NoError(t, err) + require.Equal(t, uint16(1), sb.size) + + pkt, err := pm.NewPacket(&rtp.Header{SequenceNumber: 1}, nil, 1, 1) + require.NoError(t, err) + sb.Add(pkt) + + // change payload + + retrieved := sb.Get(1) + require.NotNil(t, retrieved) + require.Equal(t, "\x00\x01", string(retrieved.Payload())) + retrieved.Release() + require.Equal(t, 1, retrieved.count) + + // ensure original packet is released + pkt, err = pm.NewPacket(&rtp.Header{SequenceNumber: 2}, []byte("altered"), 1, 1) + require.NoError(t, err) + sb.Add(pkt) + require.Equal(t, 0, retrieved.count) + + require.Nil(t, sb.Get(1)) +}