Skip to content

Commit

Permalink
Remove profile specific constants
Browse files Browse the repository at this point in the history
key+salt len should be determined by profile

Relates to #85
  • Loading branch information
Sean-Der committed Jul 15, 2020
1 parent add2176 commit 50c4c01
Show file tree
Hide file tree
Showing 9 changed files with 115 additions and 31 deletions.
39 changes: 22 additions & 17 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,6 @@ import (
"github.com/pion/transport/replaydetector"
)

// ProtectionProfile specifies Cipher and AuthTag details, similar to TLS cipher suite
type ProtectionProfile uint16

// Supported protection profiles
const (
ProtectionProfileAes128CmHmacSha1_80 ProtectionProfile = 0x0001
)

const (
labelSRTPEncryption = 0x00
labelSRTPAuthenticationTag = 0x01
Expand All @@ -29,13 +21,9 @@ const (
labelSRTCPAuthenticationTag = 0x04
labelSRTCPSalt = 0x05

keyLen = 16
saltLen = 14

maxROCDisorder = 100
maxSequenceNumber = 65535

authTagSize = 10
srtcpIndexSize = 4
)

Expand All @@ -61,6 +49,7 @@ type srtcpSSRCState struct {
type Context struct {
masterKey []byte
masterSalt []byte
authTagLen int

srtpSSRCStates map[uint32]*srtpSSRCState
srtpSessionKey []byte
Expand Down Expand Up @@ -89,6 +78,21 @@ type Context struct {
// decCtx, err := srtp.CreateContext(key, salt, profile, srtp.SRTPReplayProtection(256))
//
func CreateContext(masterKey, masterSalt []byte, profile ProtectionProfile, opts ...ContextOption) (c *Context, err error) {
keyLen, err := profile.keyLen()
if err != nil {
return nil, err
}

saltLen, err := profile.saltLen()
if err != nil {
return nil, err
}

authTagLen, err := profile.authTagLen()
if err != nil {
return nil, err
}

if masterKeyLen := len(masterKey); masterKeyLen != keyLen {
return c, fmt.Errorf("SRTP Master Key must be len %d, got %d", masterKey, keyLen)
} else if masterSaltLen := len(masterSalt); masterSaltLen != saltLen {
Expand All @@ -98,6 +102,7 @@ func CreateContext(masterKey, masterSalt []byte, profile ProtectionProfile, opts
c = &Context{
masterKey: masterKey,
masterSalt: masterSalt,
authTagLen: authTagLen,
srtpSSRCStates: map[uint32]*srtpSSRCState{},
srtcpSSRCStates: map[uint32]*srtcpSSRCState{},
}
Expand All @@ -115,7 +120,7 @@ func CreateContext(masterKey, masterSalt []byte, profile ProtectionProfile, opts

if c.srtpSessionKey, err = c.generateSessionKey(labelSRTPEncryption); err != nil {
return nil, err
} else if c.srtpSessionSalt, err = c.generateSessionSalt(labelSRTPSalt); err != nil {
} else if c.srtpSessionSalt, err = c.generateSessionSalt(labelSRTPSalt, saltLen); err != nil {
return nil, err
} else if c.srtpSessionAuthTag, err = c.generateSessionAuthTag(labelSRTPAuthenticationTag); err != nil {
return nil, err
Expand All @@ -127,7 +132,7 @@ func CreateContext(masterKey, masterSalt []byte, profile ProtectionProfile, opts

if c.srtcpSessionKey, err = c.generateSessionKey(labelSRTCPEncryption); err != nil {
return nil, err
} else if c.srtcpSessionSalt, err = c.generateSessionSalt(labelSRTCPSalt); err != nil {
} else if c.srtcpSessionSalt, err = c.generateSessionSalt(labelSRTCPSalt, saltLen); err != nil {
return nil, err
} else if c.srtcpSessionAuthTag, err = c.generateSessionAuthTag(labelSRTCPAuthenticationTag); err != nil {
return nil, err
Expand Down Expand Up @@ -166,7 +171,7 @@ func (c *Context) generateSessionKey(label byte) ([]byte, error) {
return sessionKey, nil
}

func (c *Context) generateSessionSalt(label byte) ([]byte, error) {
func (c *Context) generateSessionSalt(label byte, saltLen int) ([]byte, error) {
// https://tools.ietf.org/html/rfc3711#appendix-B.3
// The input block for AES-CM is generated by exclusive-oring the master salt with
// the concatenation of the encryption salt label
Expand Down Expand Up @@ -268,7 +273,7 @@ func (c *Context) generateSrtpAuthTag(buf []byte, roc uint32) ([]byte, error) {
}

// Truncate the hash to the first 10 bytes.
return c.srtpSessionAuth.Sum(nil)[0:10], nil
return c.srtpSessionAuth.Sum(nil)[0:c.authTagLen], nil
}

func (c *Context) generateSrtcpAuthTag(buf []byte) ([]byte, error) {
Expand All @@ -289,7 +294,7 @@ func (c *Context) generateSrtcpAuthTag(buf []byte) ([]byte, error) {
return nil, err
}

return c.srtcpSessionAuth.Sum(nil)[0:10], nil
return c.srtcpSessionAuth.Sum(nil)[0:c.authTagLen], nil
}

// https://tools.ietf.org/html/rfc3550#appendix-A.1
Expand Down
10 changes: 10 additions & 0 deletions keying.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@ type KeyingMaterialExporter interface {
// extracting them from DTLS. This behavior is defined in RFC5764:
// https://tools.ietf.org/html/rfc5764
func (c *Config) ExtractSessionKeysFromDTLS(exporter KeyingMaterialExporter, isClient bool) error {
keyLen, err := c.Profile.keyLen()
if err != nil {
return err
}

saltLen, err := c.Profile.saltLen()
if err != nil {
return err
}

keyingMaterial, err := exporter.ExportKeyingMaterial(labelExtractorDtlsSrtp, nil, (keyLen*2)+(saltLen*2))
if err != nil {
return err
Expand Down
38 changes: 38 additions & 0 deletions protection_profile.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package srtp

import "fmt"

// ProtectionProfile specifies Cipher and AuthTag details, similar to TLS cipher suite
type ProtectionProfile uint16

// Supported protection profiles
const (
ProtectionProfileAes128CmHmacSha1_80 ProtectionProfile = 0x0001
)

func (p ProtectionProfile) keyLen() (int, error) {
switch p {
case ProtectionProfileAes128CmHmacSha1_80:
return 16, nil
default:
return 0, fmt.Errorf("no such ProtectionProfile %#v", p)
}
}

func (p ProtectionProfile) saltLen() (int, error) {
switch p {
case ProtectionProfileAes128CmHmacSha1_80:
return 14, nil
default:
return 0, fmt.Errorf("no such ProtectionProfile %#v", p)
}
}

func (p ProtectionProfile) authTagLen() (int, error) {
switch p {
case ProtectionProfileAes128CmHmacSha1_80:
return 10, nil
default:
return 0, fmt.Errorf("no such ProtectionProfile %#v", p)
}
}
17 changes: 17 additions & 0 deletions protection_profile_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package srtp

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestInvalidProtectionProfile(t *testing.T) {
var invalidProtectionProfile ProtectionProfile

_, err := invalidProtectionProfile.keyLen()
assert.Error(t, err)

_, err = invalidProtectionProfile.saltLen()
assert.Error(t, err)
}
5 changes: 5 additions & 0 deletions session_srtcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,11 @@ func TestSessionSRTCPReplayProtection(t *testing.T) {
}

func getSenderSSRC(t *testing.T, stream *ReadStreamSRTCP) (ssrc uint32, err error) {
authTagSize, err := ProtectionProfileAes128CmHmacSha1_80.authTagLen()
if err != nil {
return 0, err
}

const pliPacketSize = 8
readBuffer := make([]byte, pliPacketSize+authTagSize+srtcpIndexSize)
n, _, err := stream.ReadRTCP(readBuffer)
Expand Down
6 changes: 3 additions & 3 deletions srtcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ const maxSRTCPIndex = 0x7FFFFFFF
func (c *Context) decryptRTCP(dst, encrypted []byte) ([]byte, error) {
out := allocateIfMismatch(dst, encrypted)

tailOffset := len(encrypted) - (authTagSize + srtcpIndexSize)
tailOffset := len(encrypted) - (c.authTagLen + srtcpIndexSize)
out = out[0:tailOffset]

isEncrypted := encrypted[tailOffset] >> 7
Expand All @@ -34,8 +34,8 @@ func (c *Context) decryptRTCP(dst, encrypted []byte) ([]byte, error) {
return nil, errDuplicated
}

actualTag := encrypted[len(encrypted)-authTagSize:]
expectedTag, err := c.generateSrtcpAuthTag(encrypted[:len(encrypted)-authTagSize])
actualTag := encrypted[len(encrypted)-c.authTagLen:]
expectedTag, err := c.generateSrtcpAuthTag(encrypted[:len(encrypted)-c.authTagLen])
if err != nil {
return nil, err
}
Expand Down
12 changes: 6 additions & 6 deletions srtcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func TestRTCPLifecycleInPlace(t *testing.T) {
t.Error(err)
} else if decryptHeader.Type != rtcp.TypeSenderReport {
t.Fatal("DecryptRTCP failed to populate input rtcp.Header")
} else if !bytes.Equal(decryptInput[:len(decryptInput)-(authTagSize+srtcpIndexSize)], actualDecrypted) {
} else if !bytes.Equal(decryptInput[:len(decryptInput)-(decryptContext.authTagLen+srtcpIndexSize)], actualDecrypted) {
t.Fatal("DecryptRTP failed to decrypt in place")
}

Expand All @@ -81,7 +81,7 @@ func TestRTCPLifecycleInPlace(t *testing.T) {
t.Error(err)
} else if encryptHeader.Type != rtcp.TypeSenderReport {
t.Fatal("EncryptRTCP failed to populate input rtcp.Header")
} else if !bytes.Equal(encryptInput, actualEncrypted[:len(actualEncrypted)-(authTagSize+srtcpIndexSize)]) {
} else if !bytes.Equal(encryptInput, actualEncrypted[:len(actualEncrypted)-(decryptContext.authTagLen+srtcpIndexSize)]) {
t.Fatal("EncryptRTCP failed to encrypt in place")
}

Expand Down Expand Up @@ -142,7 +142,7 @@ func TestRTCPInvalidAuthTag(t *testing.T) {
assert.Equal(decryptResult, rtcpTestDecrypted, "RTCP failed to decrypt")

// Zero out auth tag
copy(rtcpPacket[len(rtcpPacket)-authTagSize:], make([]byte, authTagSize))
copy(rtcpPacket[len(rtcpPacket)-decryptContext.authTagLen:], make([]byte, decryptContext.authTagLen))

if _, err = decryptContext.DecryptRTCP(nil, rtcpPacket, nil); err == nil {
t.Errorf("Was able to decrypt RTCP packet with invalid Auth Tag")
Expand Down Expand Up @@ -181,8 +181,8 @@ func TestRTCPReplayDetectorSeparation(t *testing.T) {
}
}

func getRTCPIndex(encrypted []byte) uint32 {
tailOffset := len(encrypted) - (authTagSize + srtcpIndexSize)
func getRTCPIndex(encrypted []byte, authTagLen int) uint32 {
tailOffset := len(encrypted) - (authTagLen + srtcpIndexSize)
srtcpIndexBuffer := encrypted[tailOffset : tailOffset+srtcpIndexSize]
return binary.BigEndian.Uint32(srtcpIndexBuffer) &^ (1 << 31)
}
Expand Down Expand Up @@ -210,7 +210,7 @@ func TestEncryptRTCPSeparation(t *testing.T) {
}

for i, expectedIndex := range []uint32{1, 1, 2, 2} {
assert.Equal(expectedIndex, getRTCPIndex(encryptedRCTPs[i]), "RTCP index does not match")
assert.Equal(expectedIndex, getRTCPIndex(encryptedRCTPs[i], decryptContext.authTagLen), "RTCP index does not match")
}

for i, output := range encryptedRCTPs {
Expand Down
6 changes: 3 additions & 3 deletions srtp.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ func (c *Context) decryptRTP(dst, ciphertext []byte, header *rtp.Header) ([]byte
return nil, errDuplicated
}

dst = growBufferSize(dst, len(ciphertext)-authTagSize)
dst = growBufferSize(dst, len(ciphertext)-c.authTagLen)

roc, updateROC := s.nextRolloverCount(header.SequenceNumber)

// Split the auth tag and the cipher text into two parts.
actualTag := ciphertext[len(ciphertext)-authTagSize:]
ciphertext = ciphertext[:len(ciphertext)-authTagSize]
actualTag := ciphertext[len(ciphertext)-c.authTagLen:]
ciphertext = ciphertext[:len(ciphertext)-c.authTagLen]

// Generate the auth tag we expect to see from the ciphertext.
expectedTag, err := c.generateSrtpAuthTag(ciphertext, roc)
Expand Down
13 changes: 11 additions & 2 deletions srtp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ type rtpTestCase struct {
}

func TestKeyLen(t *testing.T) {
keyLen, err := cipherContextAlgo.keyLen()
assert.NoError(t, err)

saltLen, err := cipherContextAlgo.saltLen()
assert.NoError(t, err)

if _, err := CreateContext([]byte{}, make([]byte, saltLen), cipherContextAlgo); err == nil {
t.Errorf("CreateContext accepted a 0 length key")
}
Expand Down Expand Up @@ -50,7 +56,10 @@ func TestValidSessionKeys(t *testing.T) {
t.Errorf("Session Key % 02x does not match expected % 02x", sessionKey, expectedSessionKey)
}

sessionSalt, err := c.generateSessionSalt(labelSRTPSalt)
saltLen, err := cipherContextAlgo.saltLen()
assert.NoError(t, err)

sessionSalt, err := c.generateSessionSalt(labelSRTPSalt, saltLen)
if err != nil {
t.Errorf("generateSessionSalt failed: %v", err)
} else if !bytes.Equal(sessionSalt, expectedSessionSalt) {
Expand Down Expand Up @@ -248,7 +257,7 @@ func TestRTPLifecyleNewAlloc(t *testing.T) {
actualDecrypted, err := decryptContext.DecryptRTP(nil, encryptedRaw, nil)
if err != nil {
t.Fatal(err)
} else if bytes.Equal(encryptedRaw[:len(encryptedRaw)-authTagSize], actualDecrypted) {
} else if bytes.Equal(encryptedRaw[:len(encryptedRaw)-encryptContext.authTagLen], actualDecrypted) {
t.Fatal("DecryptRTP improperly encrypted in place")
}

Expand Down

0 comments on commit 50c4c01

Please sign in to comment.