Skip to content

Commit

Permalink
Fix race in a03526d
Browse files Browse the repository at this point in the history
Modify packets before adding them to NACK buffer.

WARNING: DATA RACE
Read at 0x00c00056d456 by goroutine 10741:
  github.com/pion/interceptor/pkg/nack.(*sendBuffer).get()
      /home/runner/go/pkg/mod/github.com/pion/[email protected]/pkg/nack/send_buffer.go:95 +0x1e4
  github.com/pion/interceptor/pkg/nack.(*ResponderInterceptor).resendPackets.func1()
      /home/runner/go/pkg/mod/github.com/pion/[email protected]/pkg/nack/responder_interceptor.go:153 +0x70
  github.com/pion/rtcp.(*NackPair).Range()
      /home/runner/go/pkg/mod/github.com/pion/[email protected]/transport_layer_nack.go:65 +0x43
  github.com/pion/interceptor/pkg/nack.(*ResponderInterceptor).resendPackets()
      /home/runner/go/pkg/mod/github.com/pion/[email protected]/pkg/nack/responder_interceptor.go:152 +0x124
  github.com/pion/interceptor/pkg/nack.(*ResponderInterceptor).BindRTCPReader.func1.gowrap1()
      /home/runner/go/pkg/mod/github.com/pion/[email protected]/pkg/nack/responder_interceptor.go:100 +0x44

Previous write at 0x00c00056d456 by goroutine 10735:
  github.com/pion/interceptor/pkg/nack.(*ResponderInterceptor).resendPackets.func1()
      /home/runner/go/pkg/mod/github.com/pion/[email protected]/pkg/nack/responder_interceptor.go:186 +0x684
  github.com/pion/rtcp.(*NackPair).Range()
      /home/runner/go/pkg/mod/github.com/pion/[email protected]/transport_layer_nack.go:65 +0x43
  github.com/pion/interceptor/pkg/nack.(*ResponderInterceptor).resendPackets()
      /home/runner/go/pkg/mod/github.com/pion/[email protected]/pkg/nack/responder_interceptor.go:152 +0x124
  github.com/pion/interceptor/pkg/nack.(*ResponderInterceptor).BindRTCPReader.func1.gowrap1()
      /home/runner/go/pkg/mod/github.com/pion/[email protected]/pkg/nack/responder_interceptor.go:100 +0x44
  • Loading branch information
Sean-Der committed Oct 5, 2024
1 parent a03526d commit bfda3e0
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 65 deletions.
56 changes: 6 additions & 50 deletions pkg/nack/responder_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
package nack

import (
"encoding/binary"
"sync"

"github.com/pion/interceptor"
Expand All @@ -19,7 +18,7 @@ type ResponderInterceptorFactory struct {
}

type packetFactory interface {
NewPacket(header *rtp.Header, payload []byte) (*retainablePacket, error)
NewPacket(header *rtp.Header, payload []byte, rtxSsrc uint32, rtxPayloadType uint8) (*retainablePacket, error)
}

// NewInterceptor constructs a new ResponderInterceptor
Expand Down Expand Up @@ -63,11 +62,6 @@ type ResponderInterceptor struct {
type localStream struct {
sendBuffer *sendBuffer
rtpWriter interceptor.RTPWriter

// Non-zero if Retransmissions should be sent on a distinct stream
rtxSsrc uint32
rtxPayloadType uint8
rtxSequencer rtp.Sequencer
}

// NewResponderInterceptor returns a new ResponderInterceptorFactor
Expand Down Expand Up @@ -115,16 +109,13 @@ func (n *ResponderInterceptor) BindLocalStream(info *interceptor.StreamInfo, wri
sendBuffer, _ := newSendBuffer(n.size)
n.streamsMu.Lock()
n.streams[info.SSRC] = &localStream{
sendBuffer: sendBuffer,
rtpWriter: writer,
rtxSsrc: info.SSRCRetransmission,
rtxPayloadType: info.PayloadTypeRetransmission,
rtxSequencer: rtp.NewRandomSequencer(),
sendBuffer: sendBuffer,
rtpWriter: writer,
}
n.streamsMu.Unlock()

return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) {
pkt, err := n.packetFactory.NewPacket(header, payload)
pkt, err := n.packetFactory.NewPacket(header, payload, info.SSRCRetransmission, info.PayloadTypeRetransmission)
if err != nil {
return 0, err
}
Expand All @@ -151,43 +142,8 @@ func (n *ResponderInterceptor) resendPackets(nack *rtcp.TransportLayerNack) {
for i := range nack.Nacks {
nack.Nacks[i].Range(func(seq uint16) bool {
if p := stream.sendBuffer.get(seq); p != nil {
if stream.rtxSsrc != 0 {
// Store the original sequence number and rewrite the sequence number.
originalSequenceNumber := p.Header().SequenceNumber
p.Header().SequenceNumber = stream.rtxSequencer.NextSequenceNumber()

// Rewrite the SSRC.
p.Header().SSRC = stream.rtxSsrc
// Rewrite the payload type.
p.Header().PayloadType = stream.rtxPayloadType

// Remove padding if present.
paddingLength := 0
originPayload := p.Payload()
if p.Header().Padding {
paddingLength = int(originPayload[len(originPayload)-1])
p.Header().Padding = false
}

// Write the original sequence number at the beginning of the payload.
payload := make([]byte, 2)
binary.BigEndian.PutUint16(payload, originalSequenceNumber)
payload = append(payload, originPayload[:len(originPayload)-paddingLength]...)

// Send RTX packet.
if _, err := stream.rtpWriter.Write(p.Header(), payload, interceptor.Attributes{}); err != nil {
n.log.Warnf("failed sending rtx packet: %+v", err)
}

// Resore the Padding and SSRC.
if paddingLength > 0 {
p.Header().Padding = true
}
p.Header().SequenceNumber = originalSequenceNumber
} else {
if _, err := stream.rtpWriter.Write(p.Header(), p.Payload(), interceptor.Attributes{}); err != nil {
n.log.Warnf("failed resending nacked packet: %+v", err)
}
if _, err := stream.rtpWriter.Write(p.Header(), p.Payload(), interceptor.Attributes{}); err != nil {
n.log.Warnf("failed resending nacked packet: %+v", err)
}
p.Release()
}
Expand Down
48 changes: 39 additions & 9 deletions pkg/nack/retainable_packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package nack

import (
"encoding/binary"
"io"
"sync"

Expand All @@ -13,8 +14,9 @@ import (
const maxPayloadLen = 1460

type packetManager struct {
headerPool *sync.Pool
payloadPool *sync.Pool
headerPool *sync.Pool
payloadPool *sync.Pool
rtxSequencer rtp.Sequencer
}

func newPacketManager() *packetManager {
Expand All @@ -30,16 +32,18 @@ func newPacketManager() *packetManager {
return &buf
},
},
rtxSequencer: rtp.NewRandomSequencer(),
}
}

func (m *packetManager) NewPacket(header *rtp.Header, payload []byte) (*retainablePacket, error) {
func (m *packetManager) NewPacket(header *rtp.Header, payload []byte, rtxSsrc uint32, rtxPayloadType uint8) (*retainablePacket, error) {
if len(payload) > maxPayloadLen {
return nil, io.ErrShortBuffer
}

p := &retainablePacket{
onRelease: m.releasePacket,
onRelease: m.releasePacket,
sequenceNumber: header.SequenceNumber,
// new packets have retain count of 1
count: 1,
}
Expand All @@ -62,6 +66,29 @@ func (m *packetManager) NewPacket(header *rtp.Header, payload []byte) (*retainab
p.payload = (*p.buffer)[:size]
}

if rtxSsrc != 0 && rtxPayloadType != 0 {
// Store the original sequence number and rewrite the sequence number.
originalSequenceNumber := p.header.SequenceNumber
p.header.SequenceNumber = m.rtxSequencer.NextSequenceNumber()

// Rewrite the SSRC.
p.header.SSRC = rtxSsrc
// Rewrite the payload type.
p.header.PayloadType = rtxPayloadType

// Remove padding if present.
paddingLength := 0
if p.header.Padding {
paddingLength = int(p.payload[len(p.payload)-1])
p.header.Padding = false
}

// Write the original sequence number at the beginning of the payload.
payload := make([]byte, 2)
binary.BigEndian.PutUint16(payload, originalSequenceNumber)
p.payload = append(payload, p.payload[:len(p.payload)-paddingLength]...)
}

return p, nil
}

Expand All @@ -74,12 +101,13 @@ func (m *packetManager) releasePacket(header *rtp.Header, payload *[]byte) {

type noOpPacketFactory struct{}

func (f *noOpPacketFactory) NewPacket(header *rtp.Header, payload []byte) (*retainablePacket, error) {
func (f *noOpPacketFactory) NewPacket(header *rtp.Header, payload []byte, _ uint32, _ uint8) (*retainablePacket, error) {
return &retainablePacket{
onRelease: f.releasePacket,
count: 1,
header: header,
payload: payload,
onRelease: f.releasePacket,
count: 1,
header: header,
payload: payload,
sequenceNumber: header.SequenceNumber,
}, nil
}

Expand All @@ -96,6 +124,8 @@ type retainablePacket struct {
header *rtp.Header
buffer *[]byte
payload []byte

sequenceNumber uint16
}

func (p *retainablePacket) Header() *rtp.Header {
Expand Down
4 changes: 2 additions & 2 deletions pkg/nack/send_buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func (s *sendBuffer) add(packet *retainablePacket) {
s.m.Lock()
defer s.m.Unlock()

seq := packet.Header().SequenceNumber
seq := packet.sequenceNumber
if !s.started {
s.packets[seq%s.size] = packet
s.lastAdded = seq
Expand Down Expand Up @@ -92,7 +92,7 @@ func (s *sendBuffer) get(seq uint16) *retainablePacket {

pkt := s.packets[seq%s.size]
if pkt != nil {
if pkt.Header().SequenceNumber != seq {
if pkt.sequenceNumber != seq {
return nil
}
// already released
Expand Down
8 changes: 4 additions & 4 deletions pkg/nack/send_buffer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func TestSendBuffer(t *testing.T) {
add := func(nums ...uint16) {
for _, n := range nums {
seq := start + n
pkt, err := pm.NewPacket(&rtp.Header{SequenceNumber: seq}, nil)
pkt, err := pm.NewPacket(&rtp.Header{SequenceNumber: seq}, nil, 0, 0)
require.NoError(t, err)
sb.add(pkt)
}
Expand Down Expand Up @@ -78,7 +78,7 @@ func TestSendBuffer_Overridden(t *testing.T) {
require.Equal(t, uint16(1), sb.size)

originalBytes := []byte("originalContent")
pkt, err := pm.NewPacket(&rtp.Header{SequenceNumber: 1}, originalBytes)
pkt, err := pm.NewPacket(&rtp.Header{SequenceNumber: 1}, originalBytes, 0, 0)
require.NoError(t, err)
sb.add(pkt)

Expand All @@ -91,7 +91,7 @@ func TestSendBuffer_Overridden(t *testing.T) {
require.Equal(t, 1, retrieved.count)

// ensure original packet is released
pkt, err = pm.NewPacket(&rtp.Header{SequenceNumber: 2}, originalBytes)
pkt, err = pm.NewPacket(&rtp.Header{SequenceNumber: 2}, originalBytes, 0, 0)
require.NoError(t, err)
sb.add(pkt)
require.Equal(t, 0, retrieved.count)
Expand All @@ -113,7 +113,7 @@ func TestSendBuffer_Race(t *testing.T) {
add := func(nums ...uint16) {
for _, n := range nums {
seq := start + n
pkt, err := pm.NewPacket(&rtp.Header{SequenceNumber: seq}, nil)
pkt, err := pm.NewPacket(&rtp.Header{SequenceNumber: seq}, nil, 0, 0)
require.NoError(t, err)
sb.add(pkt)
}
Expand Down

0 comments on commit bfda3e0

Please sign in to comment.