From 7046799f7cd5c06e0a735464bdf004a91ccf9e51 Mon Sep 17 00:00:00 2001 From: Atsushi Watanabe Date: Fri, 6 Mar 2020 17:41:59 +0900 Subject: [PATCH] Add replay attack protection Implements RFC 6347 Section 4.1.2.6. Set config.ReplayProtectionWindow to change the size of the protection window. Default is 64. --- config.go | 6 + conn.go | 70 +++++++++- conn_test.go | 3 + internal/replaydetector/fixedbig.go | 78 +++++++++++ internal/replaydetector/fixedbig_test.go | 58 ++++++++ internal/replaydetector/replaydetector.go | 51 +++++++ .../replaydetector/replaydetector_test.go | 132 ++++++++++++++++++ replayprotection_test.go | 125 +++++++++++++++++ state.go | 4 + 9 files changed, 520 insertions(+), 7 deletions(-) create mode 100644 internal/replaydetector/fixedbig.go create mode 100644 internal/replaydetector/fixedbig_test.go create mode 100644 internal/replaydetector/replaydetector.go create mode 100644 internal/replaydetector/replaydetector_test.go create mode 100644 replayprotection_test.go diff --git a/config.go b/config.go index 06b5dfd01..059865f80 100644 --- a/config.go +++ b/config.go @@ -95,6 +95,12 @@ type Config struct { // MTU is the length at which handshake messages will be fragmented to // fit within the maximum transmission unit (default is 1200 bytes) MTU int + + // ReplayProtectionWindow is the size of the replay attack protection window. + // Duplication of the sequence number is checked in this window size. + // Packet with sequence number older than this value compared to the latest + // accepted packet will be discarded. (default is 64) + ReplayProtectionWindow int } func defaultConnectContextMaker() (context.Context, func()) { diff --git a/conn.go b/conn.go index 504896625..c518c72d3 100644 --- a/conn.go +++ b/conn.go @@ -12,6 +12,7 @@ import ( "github.com/pion/dtls/v2/internal/closer" "github.com/pion/dtls/v2/internal/net/connctx" "github.com/pion/dtls/v2/internal/net/deadline" + "github.com/pion/dtls/v2/internal/replaydetector" "github.com/pion/logging" ) @@ -20,6 +21,8 @@ const ( cookieLength = 20 defaultNamedCurve = namedCurveX25519 inboundBufferSize = 8192 + // Default replay protection window is specified by RFC 6347 Section 4.1.2.6 + defaultReplayProtectionWindow = 64 ) var invalidKeyingLabels = map[string]bool{ @@ -62,6 +65,8 @@ type Conn struct { cancelHandshakeReader func() fsm *handshakeFSM + + replayProtectionWindow uint } func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient bool, initialState *State) (*Conn, error) { @@ -96,6 +101,11 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient mtu = defaultMTU } + replayProtectionWindow := config.ReplayProtectionWindow + if replayProtectionWindow <= 0 { + replayProtectionWindow = defaultReplayProtectionWindow + } + c := &Conn{ nextConn: connctx.New(nextConn), fragmentBuffer: newFragmentBuffer(), @@ -113,11 +123,16 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient closed: closer.NewCloser(), cancelHandshaker: func() {}, + replayProtectionWindow: uint(replayProtectionWindow), + state: State{ isClient: isClient, }, } + c.setRemoteEpoch(0) + c.setLocalEpoch(0) + serverName := config.ServerName // Use host from conn address when serverName is not provided if isClient && serverName == "" && nextConn.RemoteAddr() != nil { @@ -159,6 +174,7 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient initialFlight = flight6 } initialFSMState = handshakeFinished + c.state = *initialState } else { if c.state.isClient { @@ -563,7 +579,7 @@ func (c *Conn) readAndBuffer(ctx context.Context) error { var hasHandshake bool for _, p := range pkts { - hs, alert, err := c.handleIncomingPacket(p) + hs, alert, err := c.handleIncomingPacket(p, true) if alert != nil { if alertErr := c.notify(ctx, alert.alertLevel, alert.alertDescription); alertErr != nil { if err == nil { @@ -602,7 +618,7 @@ func (c *Conn) handleQueuedPackets(ctx context.Context) error { c.encryptedPackets = nil for _, p := range pkts { - _, alert, err := c.handleIncomingPacket(p) + _, alert, err := c.handleIncomingPacket(p, false) // don't re-enqueue if alert != nil { if alertErr := c.notify(ctx, alert.alertLevel, alert.alertDescription); alertErr != nil { if err == nil { @@ -623,17 +639,50 @@ func (c *Conn) handleQueuedPackets(ctx context.Context) error { return nil } -func (c *Conn) handleIncomingPacket(buf []byte) (bool, *alert, error) { +func (c *Conn) handleIncomingPacket(buf []byte, enqueue bool) (bool, *alert, error) { // TODO: avoid separate unmarshal h := &recordLayerHeader{} if err := h.Unmarshal(buf); err != nil { return false, &alert{alertLevelFatal, alertDecodeError}, err } + // Validate epoch + remoteEpoch := c.getRemoteEpoch() + if h.epoch > remoteEpoch { + if h.epoch > remoteEpoch+1 { + c.log.Debugf("discarded future packet (epoch: %d, seq: %d)", + h.epoch, h.sequenceNumber, + ) + return false, nil, nil + } + if enqueue { + c.log.Debug("received packet of next epoch, queuing packet") + c.encryptedPackets = append(c.encryptedPackets, buf) + } + return false, nil, nil + } + + // Anti-replay protection + for len(c.state.replayDetector) <= int(h.epoch) { + c.state.replayDetector = append(c.state.replayDetector, + replaydetector.New(c.replayProtectionWindow, maxSequenceNumber), + ) + } + markPacketAsValid, ok := c.state.replayDetector[int(h.epoch)].Check(h.sequenceNumber) + if !ok { + c.log.Debugf("discarded duplicated packet (epoch: %d, seq: %d)", + h.epoch, h.sequenceNumber, + ) + return false, nil, nil + } + + // Decrypt if h.epoch != 0 { if c.state.cipherSuite == nil || !c.state.cipherSuite.isInitialized() { - c.encryptedPackets = append(c.encryptedPackets, buf) - c.log.Debug("handleIncoming: Handshake not finished, queuing packet") + if enqueue { + c.encryptedPackets = append(c.encryptedPackets, buf) + c.log.Debug("handshake not finished, queuing packet") + } return false, nil, nil } @@ -652,6 +701,7 @@ func (c *Conn) handleIncomingPacket(buf []byte) (bool, *alert, error) { c.log.Debugf("defragment failed: %s", err) return false, nil, nil } else if isHandshake { + markPacketAsValid() for out, epoch := c.fragmentBuffer.pop(); out != nil; out, epoch = c.fragmentBuffer.pop() { rawHandshake := &handshake{} if err := rawHandshake.Unmarshal(out); err != nil { @@ -678,11 +728,14 @@ func (c *Conn) handleIncomingPacket(buf []byte) (bool, *alert, error) { // Respond with a close_notify [RFC5246 Section 7.2.1] a = &alert{alertLevelWarning, alertCloseNotify} } + markPacketAsValid() return false, a, &errAlert{content} case *changeCipherSpec: if c.state.cipherSuite == nil || !c.state.cipherSuite.isInitialized() { - c.encryptedPackets = append(c.encryptedPackets, buf) - c.log.Debug("handleIncoming: CipherSuite not initialized, queuing packet") + if enqueue { + c.encryptedPackets = append(c.encryptedPackets, buf) + c.log.Debugf("CipherSuite not initialized, queuing packet") + } return false, nil, nil } @@ -691,6 +744,7 @@ func (c *Conn) handleIncomingPacket(buf []byte) (bool, *alert, error) { if c.getRemoteEpoch()+1 == newRemoteEpoch { c.setRemoteEpoch(newRemoteEpoch) + markPacketAsValid() } case *applicationData: if h.epoch == 0 { @@ -701,6 +755,8 @@ func (c *Conn) handleIncomingPacket(buf []byte) (bool, *alert, error) { return false, nil, io.EOF } + markPacketAsValid() + select { case c.decrypted <- content.data: case <-c.closed.Done(): diff --git a/conn_test.go b/conn_test.go index c5c436364..457e7c96e 100644 --- a/conn_test.go +++ b/conn_test.go @@ -88,7 +88,10 @@ func TestRoutineLeakOnClose(t *testing.T) { func pipeMemory() (*Conn, *Conn, error) { // In memory pipe ca, cb := dpipe.Pipe() + return pipeConn(ca, cb) +} +func pipeConn(ca, cb net.Conn) (*Conn, *Conn, error) { type result struct { c *Conn err error diff --git a/internal/replaydetector/fixedbig.go b/internal/replaydetector/fixedbig.go new file mode 100644 index 000000000..a571a1aad --- /dev/null +++ b/internal/replaydetector/fixedbig.go @@ -0,0 +1,78 @@ +package replaydetector + +import ( + "fmt" +) + +// fixedBigInt is the fix-sized multi-word integer. +type fixedBigInt struct { + bits []uint64 + n uint + msbMask uint64 +} + +// newFixedBigInt creates a new fix-sized multi-word int. +func newFixedBigInt(n uint) *fixedBigInt { + chunkSize := (n + 63) / 64 + if chunkSize == 0 { + chunkSize = 1 + } + return &fixedBigInt{ + bits: make([]uint64, chunkSize), + n: n, + msbMask: (1 << (64 - n%64)) - 1, + } +} + +// Lsh is the left shift operation. +func (s *fixedBigInt) Lsh(n uint) { + if n == 0 { + return + } + nChunk := int(n / 64) + nN := n % 64 + + for i := len(s.bits) - 1; i >= 0; i-- { + var carry uint64 + if i-nChunk >= 0 { + carry = s.bits[i-nChunk] << nN + if i-nChunk-1 >= 0 { + carry |= s.bits[i-nChunk-1] >> (64 - nN) + } + } + s.bits[i] = (s.bits[i] << n) | carry + } + s.bits[len(s.bits)-1] &= s.msbMask +} + +// Bit returns i-th bit of the fixedBigInt. +func (s *fixedBigInt) Bit(i uint) uint { + if i >= s.n { + return 0 + } + chunk := i / 64 + pos := i % 64 + if s.bits[chunk]&(1<= s.n { + return + } + chunk := i / 64 + pos := i % 64 + s.bits[chunk] |= 1 << pos +} + +// String returns string representation of fixedBigInt. +func (s *fixedBigInt) String() string { + var out string + for i := len(s.bits) - 1; i >= 0; i-- { + out += fmt.Sprintf("%016X", s.bits[i]) + } + return out +} diff --git a/internal/replaydetector/fixedbig_test.go b/internal/replaydetector/fixedbig_test.go new file mode 100644 index 000000000..716a53365 --- /dev/null +++ b/internal/replaydetector/fixedbig_test.go @@ -0,0 +1,58 @@ +package replaydetector + +import ( + "fmt" +) + +func Example_fixedBigInt_SetBit() { + bi := newFixedBigInt(224) + + bi.SetBit(0) + fmt.Println(bi.String()) + bi.Lsh(1) + fmt.Println(bi.String()) + + bi.Lsh(0) + fmt.Println(bi.String()) + + bi.SetBit(10) + fmt.Println(bi.String()) + bi.Lsh(20) + fmt.Println(bi.String()) + + bi.SetBit(80) + fmt.Println(bi.String()) + bi.Lsh(4) + fmt.Println(bi.String()) + + bi.SetBit(130) + fmt.Println(bi.String()) + bi.Lsh(64) + fmt.Println(bi.String()) + + bi.SetBit(7) + fmt.Println(bi.String()) + + bi.Lsh(129) + fmt.Println(bi.String()) + + for i := 0; i < 256; i++ { + bi.Lsh(1) + bi.SetBit(0) + } + fmt.Println(bi.String()) + + // output: + // 0000000000000000000000000000000000000000000000000000000000000001 + // 0000000000000000000000000000000000000000000000000000000000000002 + // 0000000000000000000000000000000000000000000000000000000000000002 + // 0000000000000000000000000000000000000000000000000000000000000402 + // 0000000000000000000000000000000000000000000000000000000040200000 + // 0000000000000000000000000000000000000000000100000000000040200000 + // 0000000000000000000000000000000000000000001000000000000402000000 + // 0000000000000000000000000000000400000000001000000000000402000000 + // 0000000000000004000000000010000000000004020000000000000000000000 + // 0000000000000004000000000010000000000004020000000000000000000080 + // 0000000004000000000000000000010000000000000000000000000000000000 + // 00000000FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF +} diff --git a/internal/replaydetector/replaydetector.go b/internal/replaydetector/replaydetector.go new file mode 100644 index 000000000..ac283a683 --- /dev/null +++ b/internal/replaydetector/replaydetector.go @@ -0,0 +1,51 @@ +// Package replaydetector provides packet replay detection algorithm. +package replaydetector + +// ReplayDetector is the interface of sequence replay detector. +type ReplayDetector interface { + // Check returns true if given sequence number is not replayed. + // Call accept() to mark the packet is received properly. + Check(seq uint64) (accept func(), ok bool) +} + +type slidingWindowDetector struct { + latestSeq uint64 + maxSeq uint64 + windowSize uint + mask *fixedBigInt +} + +// New creates ReplayDetector. +func New(windowSize uint, maxSeq uint64) ReplayDetector { + return &slidingWindowDetector{ + maxSeq: maxSeq, + windowSize: windowSize, + mask: newFixedBigInt(windowSize), + } +} + +func (d *slidingWindowDetector) Check(seq uint64) (accept func(), ok bool) { + if seq > d.maxSeq { + // Exceeded upper limit. + return func() {}, false + } + + if seq <= d.latestSeq { + if d.latestSeq > uint64(d.windowSize)+seq { + return func() {}, false + } + if d.mask.Bit(uint(d.latestSeq-seq)) != 0 { + // The sequence number is duplicated. + return func() {}, false + } + } + + return func() { + if seq > d.latestSeq { + // Update the head of the window. + d.mask.Lsh(uint(seq - d.latestSeq)) + d.latestSeq = seq + } + d.mask.SetBit(uint(d.latestSeq - seq)) + }, true +} diff --git a/internal/replaydetector/replaydetector_test.go b/internal/replaydetector/replaydetector_test.go new file mode 100644 index 000000000..4ecc59c1a --- /dev/null +++ b/internal/replaydetector/replaydetector_test.go @@ -0,0 +1,132 @@ +package replaydetector + +import ( + "reflect" + "testing" +) + +func TestReplayDetector(t *testing.T) { + const largeSeq = 0x100000000000 + cases := map[string]struct { + windowSize uint + maxSeq uint64 + input []uint64 + valid []bool + expected []uint64 + }{ + "Continuous": {16, 0x0000FFFFFFFFFFFF, + []uint64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}, + []bool{ + true, true, true, true, true, true, true, true, true, true, + true, true, true, true, true, true, true, true, true, true, + true, + }, + []uint64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}, + }, + "ValidLargeJump": {16, 0x0000FFFFFFFFFFFF, + []uint64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, largeSeq, 11, largeSeq + 1, largeSeq + 2, largeSeq + 3}, + []bool{ + true, true, true, true, true, true, true, true, true, true, + true, true, true, true, true, + }, + []uint64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, largeSeq, largeSeq + 1, largeSeq + 2, largeSeq + 3}, + }, + "InvalidLargeJump": {16, 0x0000FFFFFFFFFFFF, + []uint64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, largeSeq, 11, 12, 13, 14, 15}, + []bool{ + true, true, true, true, true, true, true, true, true, true, + false, true, true, true, true, true, + }, + []uint64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15}, + }, + "DuplicateAfterValidJump": {196, 0x0000FFFFFFFFFFFF, + []uint64{0, 1, 2, 129, 0, 1, 2}, + []bool{ + true, true, true, true, true, true, true, + }, + []uint64{0, 1, 2, 129}, + }, + "DuplicateAfterInvalidJump": {196, 0x0000FFFFFFFFFFFF, + []uint64{0, 1, 2, 128, 0, 1, 2}, + []bool{ + true, true, true, false, true, true, true, + }, + []uint64{0, 1, 2}, + }, + "ContinuousOffset": {16, 0x0000FFFFFFFFFFFF, + []uint64{100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114}, + []bool{ + true, true, true, true, true, true, true, true, true, true, + true, true, true, true, true, + }, + []uint64{100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114}, + }, + "Reordered": {128, 0x0000FFFFFFFFFFFF, + []uint64{96, 64, 16, 80, 32, 48, 8, 24, 88, 40, 128, 56, 72, 112, 104, 120}, + []bool{ + true, true, true, true, true, true, true, true, true, true, + true, true, true, true, true, true, + }, + []uint64{96, 64, 16, 80, 32, 48, 8, 24, 88, 40, 128, 56, 72, 112, 104, 120}, + }, + "Old": {100, 0x0000FFFFFFFFFFFF, + []uint64{24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 8, 16}, + []bool{ + true, true, true, true, true, true, true, true, true, true, + true, true, true, true, true, true, + }, + []uint64{24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128}, + }, + "ReplayedLater": {128, 0x0000FFFFFFFFFFFF, + []uint64{16, 32, 48, 64, 80, 96, 112, 128, 16, 32, 48, 64, 80, 96, 112, 128}, + []bool{ + true, true, true, true, true, true, true, true, true, true, + true, true, true, true, true, true, + }, + []uint64{16, 32, 48, 64, 80, 96, 112, 128}, + }, + "ReplayedQuick": {128, 0x0000FFFFFFFFFFFF, + []uint64{16, 16, 32, 32, 48, 48, 64, 64, 80, 80, 96, 96, 112, 112, 128, 128}, + []bool{ + true, true, true, true, true, true, true, true, true, true, + true, true, true, true, true, true, + }, + []uint64{16, 32, 48, 64, 80, 96, 112, 128}, + }, + "Strict": {0, 0x0000FFFFFFFFFFFF, + []uint64{1, 3, 2, 4, 5, 6, 7, 8, 9, 10}, + []bool{ + true, true, true, true, true, true, true, true, true, true, + }, + []uint64{1, 3, 4, 5, 6, 7, 8, 9, 10}, + }, + "Overflow": {128, 0x0000FFFFFFFFFFFF, + []uint64{0x0000FFFFFFFFFFFE, 0x0000FFFFFFFFFFFF, 0x0001000000000000, 0x0001000000000001}, + []bool{ + true, true, true, true, + }, + []uint64{0x0000FFFFFFFFFFFE, 0x0000FFFFFFFFFFFF}, + }, + } + for name, c := range cases { + c := c + t.Run(name, func(t *testing.T) { + det := New(c.windowSize, c.maxSeq) + var out []uint64 + for i, seq := range c.input { + accept, ok := det.Check(seq) + if ok { + if c.valid[i] { + out = append(out, seq) + accept() + } + } + } + if !reflect.DeepEqual(c.expected, out) { + t.Errorf("Wrong replay detection result:\nexpected: %v\ngot: %v", + c.expected, out, + ) + } + }) + } +} diff --git a/replayprotection_test.go b/replayprotection_test.go new file mode 100644 index 000000000..116250322 --- /dev/null +++ b/replayprotection_test.go @@ -0,0 +1,125 @@ +package dtls + +import ( + "net" + "reflect" + "sync" + "testing" + "time" + + "github.com/pion/dtls/v2/internal/net/dpipe" + "github.com/pion/transport/test" +) + +func TestReplayProtection(t *testing.T) { + // Limit runtime in case of deadlocks + lim := test.TimeOut(5 * time.Second) + defer lim.Stop() + + // Check for leaking routines + report := test.CheckRoutines(t) + defer report() + + c0, c1 := dpipe.Pipe() + c2, c3 := dpipe.Pipe() + conn := []net.Conn{c0, c1, c2, c3} + + var wgReplays, wgRoutines sync.WaitGroup + + replayer := func(ca, cb net.Conn) { + defer wgRoutines.Done() + // Man in the middle + for { + b := make([]byte, 2048) + n, rerr := ca.Read(b) + if rerr != nil { + return + } + if _, werr := cb.Write(b[:n]); werr != nil { + t.Error(werr) + return + } + + wgReplays.Add(1) + go func() { + defer wgReplays.Done() + // Replay bit later + time.Sleep(time.Millisecond) + if _, werr := cb.Write(b[:n]); werr != nil { + t.Error(werr) + } + }() + } + } + wgRoutines.Add(2) + go replayer(conn[1], conn[2]) + go replayer(conn[2], conn[1]) + + ca, cb, err := pipeConn(conn[0], conn[3]) + if err != nil { + t.Fatal(err) + } + + const numMsgs = 10 + + var received [2][][]byte + for i, c := range []net.Conn{ca, cb} { + i := i + c := c + wgRoutines.Add(1) + wgReplays.Add(1) // Keep locked until the final message + var lastMsgDone sync.Once + go func() { + defer wgRoutines.Done() + for { + b := make([]byte, 2048) + n, rerr := c.Read(b) + if rerr != nil { + return + } + received[i] = append(received[i], b[:n]) + if b[0] == numMsgs-1 { + // Final message received + lastMsgDone.Do(func() { + wgReplays.Done() + }) + } + } + }() + } + + var sent [][]byte + for i := 0; i < numMsgs; i++ { + data := []byte{byte(i)} + sent = append(sent, data) + if _, werr := ca.Write(data); werr != nil { + t.Error(werr) + return + } + if _, werr := cb.Write(data); werr != nil { + t.Error(werr) + return + } + } + wgReplays.Wait() + time.Sleep(10 * time.Millisecond) // Ensure all replayed packets are sent + + for i := 0; i < 4; i++ { + if err := conn[i].Close(); err != nil { + t.Error(err) + } + } + if err := ca.Close(); err != nil { + t.Error(err) + } + if err := cb.Close(); err != nil { + t.Error(err) + } + wgRoutines.Wait() + + for _, r := range received { + if !reflect.DeepEqual(sent, r) { + t.Errorf("Received data differs, expected: %v, got: %v", sent, r) + } + } +} diff --git a/state.go b/state.go index a939f73be..63659b797 100644 --- a/state.go +++ b/state.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/gob" "sync/atomic" + + "github.com/pion/dtls/v2/internal/replaydetector" ) // State holds the dtls connection state and implements both encoding.BinaryMarshaler and encoding.BinaryUnmarshaler @@ -33,6 +35,8 @@ type State struct { localVerifyData []byte // cached VerifyData localKeySignature []byte // cached keySignature remoteCertificateVerified bool + + replayDetector []replaydetector.ReplayDetector } type serializedState struct {