Skip to content

Commit

Permalink
Add replay attack protection
Browse files Browse the repository at this point in the history
Implements RFC 6347 Section 4.1.2.6.
Set config.ReplayProtectionWindow to change the size of the
protection window. Default is 64.
  • Loading branch information
at-wat committed Mar 7, 2020
1 parent 45cabe8 commit 7046799
Show file tree
Hide file tree
Showing 9 changed files with 520 additions and 7 deletions.
6 changes: 6 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ type Config struct {
// MTU is the length at which handshake messages will be fragmented to
// fit within the maximum transmission unit (default is 1200 bytes)
MTU int

// ReplayProtectionWindow is the size of the replay attack protection window.
// Duplication of the sequence number is checked in this window size.
// Packet with sequence number older than this value compared to the latest
// accepted packet will be discarded. (default is 64)
ReplayProtectionWindow int
}

func defaultConnectContextMaker() (context.Context, func()) {
Expand Down
70 changes: 63 additions & 7 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/pion/dtls/v2/internal/closer"
"github.com/pion/dtls/v2/internal/net/connctx"
"github.com/pion/dtls/v2/internal/net/deadline"
"github.com/pion/dtls/v2/internal/replaydetector"
"github.com/pion/logging"
)

Expand All @@ -20,6 +21,8 @@ const (
cookieLength = 20
defaultNamedCurve = namedCurveX25519
inboundBufferSize = 8192
// Default replay protection window is specified by RFC 6347 Section 4.1.2.6
defaultReplayProtectionWindow = 64
)

var invalidKeyingLabels = map[string]bool{
Expand Down Expand Up @@ -62,6 +65,8 @@ type Conn struct {
cancelHandshakeReader func()

fsm *handshakeFSM

replayProtectionWindow uint
}

func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient bool, initialState *State) (*Conn, error) {
Expand Down Expand Up @@ -96,6 +101,11 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient
mtu = defaultMTU
}

replayProtectionWindow := config.ReplayProtectionWindow
if replayProtectionWindow <= 0 {
replayProtectionWindow = defaultReplayProtectionWindow
}

c := &Conn{
nextConn: connctx.New(nextConn),
fragmentBuffer: newFragmentBuffer(),
Expand All @@ -113,11 +123,16 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient
closed: closer.NewCloser(),
cancelHandshaker: func() {},

replayProtectionWindow: uint(replayProtectionWindow),

state: State{
isClient: isClient,
},
}

c.setRemoteEpoch(0)
c.setLocalEpoch(0)

serverName := config.ServerName
// Use host from conn address when serverName is not provided
if isClient && serverName == "" && nextConn.RemoteAddr() != nil {
Expand Down Expand Up @@ -159,6 +174,7 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient
initialFlight = flight6
}
initialFSMState = handshakeFinished

c.state = *initialState
} else {
if c.state.isClient {
Expand Down Expand Up @@ -563,7 +579,7 @@ func (c *Conn) readAndBuffer(ctx context.Context) error {

var hasHandshake bool
for _, p := range pkts {
hs, alert, err := c.handleIncomingPacket(p)
hs, alert, err := c.handleIncomingPacket(p, true)
if alert != nil {
if alertErr := c.notify(ctx, alert.alertLevel, alert.alertDescription); alertErr != nil {
if err == nil {
Expand Down Expand Up @@ -602,7 +618,7 @@ func (c *Conn) handleQueuedPackets(ctx context.Context) error {
c.encryptedPackets = nil

for _, p := range pkts {
_, alert, err := c.handleIncomingPacket(p)
_, alert, err := c.handleIncomingPacket(p, false) // don't re-enqueue
if alert != nil {
if alertErr := c.notify(ctx, alert.alertLevel, alert.alertDescription); alertErr != nil {
if err == nil {
Expand All @@ -623,17 +639,50 @@ func (c *Conn) handleQueuedPackets(ctx context.Context) error {
return nil
}

func (c *Conn) handleIncomingPacket(buf []byte) (bool, *alert, error) {
func (c *Conn) handleIncomingPacket(buf []byte, enqueue bool) (bool, *alert, error) {
// TODO: avoid separate unmarshal
h := &recordLayerHeader{}
if err := h.Unmarshal(buf); err != nil {
return false, &alert{alertLevelFatal, alertDecodeError}, err
}

// Validate epoch
remoteEpoch := c.getRemoteEpoch()
if h.epoch > remoteEpoch {
if h.epoch > remoteEpoch+1 {
c.log.Debugf("discarded future packet (epoch: %d, seq: %d)",
h.epoch, h.sequenceNumber,
)
return false, nil, nil
}
if enqueue {
c.log.Debug("received packet of next epoch, queuing packet")
c.encryptedPackets = append(c.encryptedPackets, buf)
}
return false, nil, nil
}

// Anti-replay protection
for len(c.state.replayDetector) <= int(h.epoch) {
c.state.replayDetector = append(c.state.replayDetector,
replaydetector.New(c.replayProtectionWindow, maxSequenceNumber),
)
}
markPacketAsValid, ok := c.state.replayDetector[int(h.epoch)].Check(h.sequenceNumber)
if !ok {
c.log.Debugf("discarded duplicated packet (epoch: %d, seq: %d)",
h.epoch, h.sequenceNumber,
)
return false, nil, nil
}

// Decrypt
if h.epoch != 0 {
if c.state.cipherSuite == nil || !c.state.cipherSuite.isInitialized() {
c.encryptedPackets = append(c.encryptedPackets, buf)
c.log.Debug("handleIncoming: Handshake not finished, queuing packet")
if enqueue {
c.encryptedPackets = append(c.encryptedPackets, buf)
c.log.Debug("handshake not finished, queuing packet")
}
return false, nil, nil
}

Expand All @@ -652,6 +701,7 @@ func (c *Conn) handleIncomingPacket(buf []byte) (bool, *alert, error) {
c.log.Debugf("defragment failed: %s", err)
return false, nil, nil
} else if isHandshake {
markPacketAsValid()
for out, epoch := c.fragmentBuffer.pop(); out != nil; out, epoch = c.fragmentBuffer.pop() {
rawHandshake := &handshake{}
if err := rawHandshake.Unmarshal(out); err != nil {
Expand All @@ -678,11 +728,14 @@ func (c *Conn) handleIncomingPacket(buf []byte) (bool, *alert, error) {
// Respond with a close_notify [RFC5246 Section 7.2.1]
a = &alert{alertLevelWarning, alertCloseNotify}
}
markPacketAsValid()
return false, a, &errAlert{content}
case *changeCipherSpec:
if c.state.cipherSuite == nil || !c.state.cipherSuite.isInitialized() {
c.encryptedPackets = append(c.encryptedPackets, buf)
c.log.Debug("handleIncoming: CipherSuite not initialized, queuing packet")
if enqueue {
c.encryptedPackets = append(c.encryptedPackets, buf)
c.log.Debugf("CipherSuite not initialized, queuing packet")
}
return false, nil, nil
}

Expand All @@ -691,6 +744,7 @@ func (c *Conn) handleIncomingPacket(buf []byte) (bool, *alert, error) {

if c.getRemoteEpoch()+1 == newRemoteEpoch {
c.setRemoteEpoch(newRemoteEpoch)
markPacketAsValid()
}
case *applicationData:
if h.epoch == 0 {
Expand All @@ -701,6 +755,8 @@ func (c *Conn) handleIncomingPacket(buf []byte) (bool, *alert, error) {
return false, nil, io.EOF
}

markPacketAsValid()

select {
case c.decrypted <- content.data:
case <-c.closed.Done():
Expand Down
3 changes: 3 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,10 @@ func TestRoutineLeakOnClose(t *testing.T) {
func pipeMemory() (*Conn, *Conn, error) {
// In memory pipe
ca, cb := dpipe.Pipe()
return pipeConn(ca, cb)
}

func pipeConn(ca, cb net.Conn) (*Conn, *Conn, error) {
type result struct {
c *Conn
err error
Expand Down
78 changes: 78 additions & 0 deletions internal/replaydetector/fixedbig.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package replaydetector

import (
"fmt"
)

// fixedBigInt is the fix-sized multi-word integer.
type fixedBigInt struct {
bits []uint64
n uint
msbMask uint64
}

// newFixedBigInt creates a new fix-sized multi-word int.
func newFixedBigInt(n uint) *fixedBigInt {
chunkSize := (n + 63) / 64
if chunkSize == 0 {
chunkSize = 1
}
return &fixedBigInt{
bits: make([]uint64, chunkSize),
n: n,
msbMask: (1 << (64 - n%64)) - 1,
}
}

// Lsh is the left shift operation.
func (s *fixedBigInt) Lsh(n uint) {
if n == 0 {
return
}
nChunk := int(n / 64)
nN := n % 64

for i := len(s.bits) - 1; i >= 0; i-- {
var carry uint64
if i-nChunk >= 0 {
carry = s.bits[i-nChunk] << nN
if i-nChunk-1 >= 0 {
carry |= s.bits[i-nChunk-1] >> (64 - nN)
}
}
s.bits[i] = (s.bits[i] << n) | carry
}
s.bits[len(s.bits)-1] &= s.msbMask
}

// Bit returns i-th bit of the fixedBigInt.
func (s *fixedBigInt) Bit(i uint) uint {
if i >= s.n {
return 0
}
chunk := i / 64
pos := i % 64
if s.bits[chunk]&(1<<pos) != 0 {
return 1
}
return 0
}

// SetBit sets i-th bit to 1.
func (s *fixedBigInt) SetBit(i uint) {
if i >= s.n {
return
}
chunk := i / 64
pos := i % 64
s.bits[chunk] |= 1 << pos
}

// String returns string representation of fixedBigInt.
func (s *fixedBigInt) String() string {
var out string
for i := len(s.bits) - 1; i >= 0; i-- {
out += fmt.Sprintf("%016X", s.bits[i])
}
return out
}
58 changes: 58 additions & 0 deletions internal/replaydetector/fixedbig_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package replaydetector

import (
"fmt"
)

func Example_fixedBigInt_SetBit() {
bi := newFixedBigInt(224)

bi.SetBit(0)
fmt.Println(bi.String())
bi.Lsh(1)
fmt.Println(bi.String())

bi.Lsh(0)
fmt.Println(bi.String())

bi.SetBit(10)
fmt.Println(bi.String())
bi.Lsh(20)
fmt.Println(bi.String())

bi.SetBit(80)
fmt.Println(bi.String())
bi.Lsh(4)
fmt.Println(bi.String())

bi.SetBit(130)
fmt.Println(bi.String())
bi.Lsh(64)
fmt.Println(bi.String())

bi.SetBit(7)
fmt.Println(bi.String())

bi.Lsh(129)
fmt.Println(bi.String())

for i := 0; i < 256; i++ {
bi.Lsh(1)
bi.SetBit(0)
}
fmt.Println(bi.String())

// output:
// 0000000000000000000000000000000000000000000000000000000000000001
// 0000000000000000000000000000000000000000000000000000000000000002
// 0000000000000000000000000000000000000000000000000000000000000002
// 0000000000000000000000000000000000000000000000000000000000000402
// 0000000000000000000000000000000000000000000000000000000040200000
// 0000000000000000000000000000000000000000000100000000000040200000
// 0000000000000000000000000000000000000000001000000000000402000000
// 0000000000000000000000000000000400000000001000000000000402000000
// 0000000000000004000000000010000000000004020000000000000000000000
// 0000000000000004000000000010000000000004020000000000000000000080
// 0000000004000000000000000000010000000000000000000000000000000000
// 00000000FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF
}
51 changes: 51 additions & 0 deletions internal/replaydetector/replaydetector.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Package replaydetector provides packet replay detection algorithm.
package replaydetector

// ReplayDetector is the interface of sequence replay detector.
type ReplayDetector interface {
// Check returns true if given sequence number is not replayed.
// Call accept() to mark the packet is received properly.
Check(seq uint64) (accept func(), ok bool)
}

type slidingWindowDetector struct {
latestSeq uint64
maxSeq uint64
windowSize uint
mask *fixedBigInt
}

// New creates ReplayDetector.
func New(windowSize uint, maxSeq uint64) ReplayDetector {
return &slidingWindowDetector{
maxSeq: maxSeq,
windowSize: windowSize,
mask: newFixedBigInt(windowSize),
}
}

func (d *slidingWindowDetector) Check(seq uint64) (accept func(), ok bool) {
if seq > d.maxSeq {
// Exceeded upper limit.
return func() {}, false
}

if seq <= d.latestSeq {
if d.latestSeq > uint64(d.windowSize)+seq {
return func() {}, false
}
if d.mask.Bit(uint(d.latestSeq-seq)) != 0 {
// The sequence number is duplicated.
return func() {}, false
}
}

return func() {
if seq > d.latestSeq {
// Update the head of the window.
d.mask.Lsh(uint(seq - d.latestSeq))
d.latestSeq = seq
}
d.mask.SetBit(uint(d.latestSeq - seq))
}, true
}
Loading

0 comments on commit 7046799

Please sign in to comment.