Skip to content

Commit

Permalink
Add replay attack protection
Browse files Browse the repository at this point in the history
Implements RFC 6347 Section 4.1.2.6.
Set config.ReplayProtectionWindow to change the size of the
protection window. Default is 64.
  • Loading branch information
at-wat committed Mar 6, 2020
1 parent 3b0a286 commit 7432d82
Show file tree
Hide file tree
Showing 8 changed files with 360 additions and 24 deletions.
6 changes: 6 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ type Config struct {
// MTU is the length at which handshake messages will be fragmented to
// fit within the maximum transmission unit (default is 1200 bytes)
MTU int

// ReplayProtectionWindow is the size of the replay attack protection window.
// Duplication of the sequence number is checked in this window size.
// Packet with sequence number older than this value compared to the latest
// accepted packet will be discarded. (default is 64)
ReplayProtectionWindow int
}

func defaultConnectContextMaker() (context.Context, func()) {
Expand Down
86 changes: 72 additions & 14 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/pion/dtls/v2/internal/closer"
"github.com/pion/dtls/v2/internal/net/connctx"
"github.com/pion/dtls/v2/internal/net/deadline"
"github.com/pion/dtls/v2/internal/replaydetector"
"github.com/pion/logging"
)

Expand Down Expand Up @@ -62,6 +63,8 @@ type Conn struct {
cancelHandshakeReader func()

fsm *handshakeFSM

replayProtectionWindow uint
}

func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient bool, initialState *State) (*Conn, error) {
Expand Down Expand Up @@ -89,13 +92,22 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient
loggerFactory = logging.NewDefaultLoggerFactory()
}

logger := loggerFactory.NewLogger("dtls")
loggerName := "dtls server"
if isClient {
loggerName = "dtls client"
}
logger := loggerFactory.NewLogger(loggerName)

mtu := config.MTU
if mtu <= 0 {
mtu = defaultMTU
}

replayProtectionWindow := config.ReplayProtectionWindow
if replayProtectionWindow <= 0 {
replayProtectionWindow = 64
}

c := &Conn{
nextConn: connctx.New(nextConn),
fragmentBuffer: newFragmentBuffer(),
Expand All @@ -113,11 +125,16 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient
closed: closer.NewCloser(),
cancelHandshaker: func() {},

replayProtectionWindow: uint(replayProtectionWindow),

state: State{
isClient: isClient,
},
}

c.setRemoteEpoch(0)
c.setLocalEpoch(0)

serverName := config.ServerName
// Use host from conn address when serverName is not provided
if isClient && serverName == "" && nextConn.RemoteAddr() != nil {
Expand Down Expand Up @@ -159,6 +176,7 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient
initialFlight = flight6
}
initialFSMState = handshakeFinished

c.state = *initialState
} else {
if c.state.isClient {
Expand Down Expand Up @@ -380,8 +398,8 @@ func (c *Conn) writePackets(ctx context.Context, pkts []*packet) error {
return err
}

c.log.Tracef("[handshake:%v] -> %s (epoch: %d, seq: %d)",
srvCliStr(c.state.isClient), h.handshakeHeader.handshakeType.String(),
c.log.Tracef("[handshake] -> %s (epoch: %d, seq: %d)",
h.handshakeHeader.handshakeType.String(),
p.record.recordLayerHeader.epoch, h.handshakeHeader.messageSequence)
c.handshakeCache.push(handshakeRaw[recordLayerHeaderSize:], p.record.recordLayerHeader.epoch, h.handshakeHeader.messageSequence, h.handshakeHeader.handshakeType, c.state.isClient)

Expand Down Expand Up @@ -563,7 +581,7 @@ func (c *Conn) readAndBuffer(ctx context.Context) error {

var hasHandshake bool
for _, p := range pkts {
hs, alert, err := c.handleIncomingPacket(p)
hs, alert, err := c.handleIncomingPacket(p, true)
if alert != nil {
if alertErr := c.notify(ctx, alert.alertLevel, alert.alertDescription); alertErr != nil {
if err == nil {
Expand Down Expand Up @@ -602,7 +620,7 @@ func (c *Conn) handleQueuedPackets(ctx context.Context) error {
c.encryptedPackets = nil

for _, p := range pkts {
_, alert, err := c.handleIncomingPacket(p)
_, alert, err := c.handleIncomingPacket(p, false) // don't re-enqueue
if alert != nil {
if alertErr := c.notify(ctx, alert.alertLevel, alert.alertDescription); alertErr != nil {
if err == nil {
Expand All @@ -623,24 +641,57 @@ func (c *Conn) handleQueuedPackets(ctx context.Context) error {
return nil
}

func (c *Conn) handleIncomingPacket(buf []byte) (bool, *alert, error) {
func (c *Conn) handleIncomingPacket(buf []byte, enqueue bool) (bool, *alert, error) {
// TODO: avoid separate unmarshal
h := &recordLayerHeader{}
if err := h.Unmarshal(buf); err != nil {
return false, &alert{alertLevelFatal, alertDecodeError}, err
}

// Validate epoch
remoteEpoch := c.getRemoteEpoch()
if h.epoch > remoteEpoch {
if h.epoch > remoteEpoch+1 {
c.log.Debugf("discarded future packet (epoch: %d, seq: %d)",
h.epoch, h.sequenceNumber,
)
return false, nil, nil
}
if enqueue {
c.log.Debug("received packet of next epoch, queuing packet")
c.encryptedPackets = append(c.encryptedPackets, buf)
}
return false, nil, nil
}

// Anti-replay protection
for len(c.state.replayDetector) <= int(h.epoch) {
c.state.replayDetector = append(c.state.replayDetector,
replaydetector.New(c.replayProtectionWindow, maxSequenceNumber),
)
}
accept, ok := c.state.replayDetector[int(h.epoch)].Check(h.sequenceNumber)
if !ok {
c.log.Debugf("discarded duplicated packet (epoch: %d, seq: %d)",
h.epoch, h.sequenceNumber,
)
return false, nil, nil
}

// Decrypt
if h.epoch != 0 {
if c.state.cipherSuite == nil || !c.state.cipherSuite.isInitialized() {
c.encryptedPackets = append(c.encryptedPackets, buf)
c.log.Debug("handleIncoming: Handshake not finished, queuing packet")
if enqueue {
c.encryptedPackets = append(c.encryptedPackets, buf)
c.log.Debug("handshake not finished, queuing packet")
}
return false, nil, nil
}

var err error
buf, err = c.state.cipherSuite.decrypt(buf)
if err != nil {
c.log.Debugf("%s: decrypt failed: %s", srvCliStr(c.state.isClient), err)
c.log.Debugf("decrypt failed: %s", err)
return false, nil, nil
}
}
Expand All @@ -652,10 +703,11 @@ func (c *Conn) handleIncomingPacket(buf []byte) (bool, *alert, error) {
c.log.Debugf("defragment failed: %s", err)
return false, nil, nil
} else if isHandshake {
accept()
for out, epoch := c.fragmentBuffer.pop(); out != nil; out, epoch = c.fragmentBuffer.pop() {
rawHandshake := &handshake{}
if err := rawHandshake.Unmarshal(out); err != nil {
c.log.Debugf("%s: handshake parse failed: %s", srvCliStr(c.state.isClient), err)
c.log.Debugf("handshake parse failed: %s", err)
continue
}

Expand All @@ -672,25 +724,29 @@ func (c *Conn) handleIncomingPacket(buf []byte) (bool, *alert, error) {

switch content := r.content.(type) {
case *alert:
c.log.Tracef("%s: <- %s", srvCliStr(c.state.isClient), content.String())
c.log.Tracef("<- %s", content.String())
var a *alert
if content.alertDescription == alertCloseNotify {
// Respond with a close_notify [RFC5246 Section 7.2.1]
a = &alert{alertLevelWarning, alertCloseNotify}
}
accept()
return false, a, &errAlert{content}
case *changeCipherSpec:
if c.state.cipherSuite == nil || !c.state.cipherSuite.isInitialized() {
c.encryptedPackets = append(c.encryptedPackets, buf)
c.log.Debug("handleIncoming: CipherSuite not initialized, queuing packet")
if enqueue {
c.encryptedPackets = append(c.encryptedPackets, buf)
c.log.Debugf("CipherSuite not initialized, queuing packet")
}
return false, nil, nil
}

newRemoteEpoch := h.epoch + 1
c.log.Tracef("%s: <- ChangeCipherSpec (epoch: %d)", srvCliStr(c.state.isClient), newRemoteEpoch)
c.log.Tracef("<- ChangeCipherSpec (epoch: %d)", newRemoteEpoch)

if c.getRemoteEpoch()+1 == newRemoteEpoch {
c.setRemoteEpoch(newRemoteEpoch)
accept()
}
case *applicationData:
if h.epoch == 0 {
Expand All @@ -701,6 +757,8 @@ func (c *Conn) handleIncomingPacket(buf []byte) (bool, *alert, error) {
return false, nil, io.EOF
}

accept()

select {
case c.decrypted <- content.data:
case <-c.closed.Done():
Expand Down
3 changes: 3 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,10 @@ func TestRoutineLeakOnClose(t *testing.T) {
func pipeMemory() (*Conn, *Conn, error) {
// In memory pipe
ca, cb := dpipe.Pipe()
return pipeConn(ca, cb)
}

func pipeConn(ca, cb net.Conn) (*Conn, *Conn, error) {
type result struct {
c *Conn
err error
Expand Down
13 changes: 3 additions & 10 deletions handshaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,6 @@ type flightConn interface {
handleQueuedPackets(context.Context) error
}

func srvCliStr(isClient bool) string {
if isClient {
return "client"
}
return "server"
}

func newHandshakeFSM(
s *State, cache *handshakeCache, cfg *handshakeConfig,
initialFlight flightVal,
Expand All @@ -142,7 +135,7 @@ func (s *handshakeFSM) Run(ctx context.Context, c flightConn, initialState hands
close(s.closed)
}()
for {
s.cfg.log.Tracef("[handshake:%s] %s: %s", srvCliStr(s.state.isClient), s.currentFlight.String(), state.String())
s.cfg.log.Tracef("[handshake] %s: %s", s.currentFlight.String(), state.String())
if s.cfg.onFlightState != nil {
s.cfg.onFlightState(s.currentFlight, state)
}
Expand Down Expand Up @@ -209,7 +202,7 @@ func (s *handshakeFSM) prepare(ctx context.Context, c flightConn) (handshakeStat
}
}
if epoch != nextEpoch {
s.cfg.log.Tracef("[handshake:%s] -> changeCipherSpec (epoch: %d)", srvCliStr(s.state.isClient), nextEpoch)
s.cfg.log.Tracef("[handshake] -> changeCipherSpec (epoch: %d)", nextEpoch)
c.setLocalEpoch(nextEpoch)
}
return handshakeSending, nil
Expand Down Expand Up @@ -257,7 +250,7 @@ func (s *handshakeFSM) wait(ctx context.Context, c flightConn) (handshakeState,
if nextFlight == 0 {
break
}
s.cfg.log.Tracef("[handshake:%s] %s -> %s", srvCliStr(s.state.isClient), s.currentFlight.String(), nextFlight.String())
s.cfg.log.Tracef("[handshake] %s -> %s", s.currentFlight.String(), nextFlight.String())
if nextFlight.isLastRecvFlight() && s.currentFlight == nextFlight {
return handshakeFinished, nil
}
Expand Down
59 changes: 59 additions & 0 deletions internal/replaydetector/replaydetector.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Package replaydetector provides packet replay detection algorithm.
package replaydetector

// ReplayDetector is the interface of sequence replay detector.
type ReplayDetector interface {
// Check returns true if given sequence number is not replayed.
// Call accept() to mark the packet is received properly.
Check(seq uint64) (accept func(), ok bool)
}

type slidingWindowDetector struct {
latestSeq uint64
maxSeq uint64
windowSize uint64
mask []uint64
}

// New creates ReplayDetector.
func New(windowSize uint, maxSeq uint64) ReplayDetector {
chunkSize := (windowSize + 63) / 64
if chunkSize == 0 {
chunkSize = 1
}
return &slidingWindowDetector{
windowSize: uint64(windowSize),
maxSeq: maxSeq,
mask: make([]uint64, chunkSize),
}
}

func (d *slidingWindowDetector) Check(seq uint64) (accept func(), ok bool) {
if seq+d.windowSize < d.latestSeq || seq > d.maxSeq {
// Older than window size or exceeding upper limit.
return func() {}, false
}
if seq > d.latestSeq {
// Update the head of the window.
shift := seq - d.latestSeq
for i := len(d.mask) - 1; i > 0; i-- {
d.mask[i] <<= shift
d.mask[i] |= d.mask[i-1] >> (64 - shift)
}
d.mask[0] <<= shift
d.latestSeq = seq
}

chunk := (d.latestSeq - seq) / 64
pos := (d.latestSeq - seq) % 64
bit := uint64(1) << pos

if d.mask[chunk]&bit != 0 {
// The sequence number is duplicated.
return func() {}, false
}

return func() {
d.mask[chunk] |= bit
}, true
}
Loading

0 comments on commit 7432d82

Please sign in to comment.