Skip to content

Commit

Permalink
Implement AEAD_AES_128_GCM
Browse files Browse the repository at this point in the history
Resolves #85
  • Loading branch information
Sean-Der committed Jul 21, 2020
1 parent 2f1e8b4 commit f871f43
Show file tree
Hide file tree
Showing 11 changed files with 356 additions and 91 deletions.
19 changes: 13 additions & 6 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,16 +75,23 @@ func CreateContext(masterKey, masterSalt []byte, profile ProtectionProfile, opts
return c, fmt.Errorf("SRTP Salt must be len %d, got %d", saltLen, masterSaltLen)
}

sCipher, err := newSrtpCipherAesCmHmacSha1(masterKey, masterSalt)
if err != nil {
return nil, err
}

c = &Context{
cipher: sCipher,
srtpSSRCStates: map[uint32]*srtpSSRCState{},
srtcpSSRCStates: map[uint32]*srtcpSSRCState{},
}

switch profile {
case ProtectionProfileAeadAes128Gcm:
c.cipher, err = newSrtpCipherAeadAesGcm(masterKey, masterSalt)
case ProtectionProfileAes128CmHmacSha1_80:
c.cipher, err = newSrtpCipherAesCmHmacSha1(masterKey, masterSalt)
default:
return nil, fmt.Errorf("no such SRTP Profile %#v", profile)
}
if err != nil {
return nil, err
}

for _, o := range append(
[]ContextOption{ // Default options
SRTPNoReplayProtection(),
Expand Down
80 changes: 19 additions & 61 deletions key_derivation.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,84 +3,42 @@ package srtp
import (
"crypto/aes"
"encoding/binary"
"errors"
)

// All of these key derivation functions are AES-CM specific
// in the future we have multiple implementations of each of these functions
func aesCmKeyDerivation(label byte, masterKey, masterSalt []byte, indexOverKdr int, outLen int) ([]byte, error) {
if indexOverKdr != 0 {
// 24-bit "index DIV kdr" must be xored to prf input.
return nil, errors.New("indexOverKdr > 0 is not supported yet")
}

func generateSessionKey(label byte, masterKey, masterSalt []byte) ([]byte, error) {
// https://tools.ietf.org/html/rfc3711#appendix-B.3
// The input block for AES-CM is generated by exclusive-oring the master salt with the
// concatenation of the encryption key label 0x00 with (index DIV kdr),
// - index is 'rollover count' and DIV is 'divided by'
sessionKey := make([]byte, len(masterSalt))
copy(sessionKey, masterSalt)

labelAndIndexOverKdr := []byte{label, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}
for i, j := len(labelAndIndexOverKdr)-1, len(sessionKey)-1; i >= 0; i, j = i-1, j-1 {
sessionKey[j] = sessionKey[j] ^ labelAndIndexOverKdr[i]
}

// then padding on the right with two null octets (which implements the multiply-by-2^16 operation, see Section 4.3.3).
sessionKey = append(sessionKey, []byte{0x00, 0x00}...)
nMasterKey := len(masterKey)
nMasterSalt := len(masterSalt)

//The resulting value is then AES-CM- encrypted using the master key to get the cipher key.
block, err := aes.NewCipher(masterKey)
if err != nil {
return nil, err
}

block.Encrypt(sessionKey, sessionKey)
return sessionKey, nil
}

func generateSessionSalt(label byte, masterKey, masterSalt []byte) ([]byte, error) {
// https://tools.ietf.org/html/rfc3711#appendix-B.3
// The input block for AES-CM is generated by exclusive-oring the master salt with
// the concatenation of the encryption salt label
sessionSalt := make([]byte, len(masterSalt))
copy(sessionSalt, masterSalt)
prfIn := make([]byte, nMasterKey)
copy(prfIn[:nMasterSalt], masterSalt)

labelAndIndexOverKdr := []byte{label, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}
for i, j := len(labelAndIndexOverKdr)-1, len(sessionSalt)-1; i >= 0; i, j = i-1, j-1 {
sessionSalt[j] = sessionSalt[j] ^ labelAndIndexOverKdr[i]
}
prfIn[7] ^= label

// That value is padded and encrypted as above.
sessionSalt = append(sessionSalt, []byte{0x00, 0x00}...)
//The resulting value is then AES encrypted using the master key to get the cipher key.
block, err := aes.NewCipher(masterKey)
if err != nil {
return nil, err
}

block.Encrypt(sessionSalt, sessionSalt)
return sessionSalt[0:len(masterSalt)], nil
}

func generateSessionAuthTag(label byte, masterKey, masterSalt []byte) ([]byte, error) {
// https://tools.ietf.org/html/rfc3711#appendix-B.3
// We now show how the auth key is generated. The input block for AES-
// CM is generated as above, but using the authentication key label.
sessionAuthTag := make([]byte, len(masterSalt))
copy(sessionAuthTag, masterSalt)

labelAndIndexOverKdr := []byte{label, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}
for i, j := len(labelAndIndexOverKdr)-1, len(sessionAuthTag)-1; i >= 0; i, j = i-1, j-1 {
sessionAuthTag[j] = sessionAuthTag[j] ^ labelAndIndexOverKdr[i]
out := make([]byte, ((outLen+nMasterKey)/nMasterKey)*nMasterKey)
var i uint16
for n := 0; n < outLen; n += nMasterKey {
binary.BigEndian.PutUint16(prfIn[nMasterKey-2:], i)
block.Encrypt(out[n:n+nMasterKey], prfIn)
i++
}

// That value is padded and encrypted as above.
// - We need to do multiple runs at key size (20) is larger then source
firstRun := append(sessionAuthTag, []byte{0x00, 0x00}...)
secondRun := append(sessionAuthTag, []byte{0x00, 0x01}...)
block, err := aes.NewCipher(masterKey)
if err != nil {
return nil, err
}

block.Encrypt(firstRun, firstRun)
block.Encrypt(secondRun, secondRun)
return append(firstRun, secondRun[:4]...), nil
return out[:outLen], nil
}

// Generate IV https://tools.ietf.org/html/rfc3711#section-4.1.1
Expand Down
18 changes: 15 additions & 3 deletions key_derivation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package srtp
import (
"bytes"
"testing"

"github.com/stretchr/testify/assert"
)

func TestValidSessionKeys(t *testing.T) {
Expand All @@ -13,24 +15,34 @@ func TestValidSessionKeys(t *testing.T) {
expectedSessionSalt := []byte{0x30, 0xCB, 0xBC, 0x08, 0x86, 0x3D, 0x8C, 0x85, 0xD4, 0x9D, 0xB3, 0x4A, 0x9A, 0xE1}
expectedSessionAuthTag := []byte{0xCE, 0xBE, 0x32, 0x1F, 0x6F, 0xF7, 0x71, 0x6B, 0x6F, 0xD4, 0xAB, 0x49, 0xAF, 0x25, 0x6A, 0x15, 0x6D, 0x38, 0xBA, 0xA4}

sessionKey, err := generateSessionKey(labelSRTPEncryption, masterKey, masterSalt)
sessionKey, err := aesCmKeyDerivation(labelSRTPEncryption, masterKey, masterSalt, 0, len(masterKey))
if err != nil {
t.Errorf("generateSessionKey failed: %v", err)
} else if !bytes.Equal(sessionKey, expectedSessionKey) {
t.Errorf("Session Key % 02x does not match expected % 02x", sessionKey, expectedSessionKey)
}

sessionSalt, err := generateSessionSalt(labelSRTPSalt, masterKey, masterSalt)
sessionSalt, err := aesCmKeyDerivation(labelSRTPSalt, masterKey, masterSalt, 0, len(masterSalt))
if err != nil {
t.Errorf("generateSessionSalt failed: %v", err)
} else if !bytes.Equal(sessionSalt, expectedSessionSalt) {
t.Errorf("Session Salt % 02x does not match expected % 02x", sessionSalt, expectedSessionSalt)
}

sessionAuthTag, err := generateSessionAuthTag(labelSRTPAuthenticationTag, masterKey, masterSalt)
authKeyLen, err := ProtectionProfileAes128CmHmacSha1_80.authKeyLen()
assert.NoError(t, err)

sessionAuthTag, err := aesCmKeyDerivation(labelSRTPAuthenticationTag, masterKey, masterSalt, 0, authKeyLen)
if err != nil {
t.Errorf("generateSessionAuthTag failed: %v", err)
} else if !bytes.Equal(sessionAuthTag, expectedSessionAuthTag) {
t.Errorf("Session Auth Tag % 02x does not match expected % 02x", sessionAuthTag, expectedSessionAuthTag)
}
}

// This test asserts that calling aesCmKeyDerivation with a non-zero indexOverKdr fails
// Currently this isn't supported, but the API makes sure we can add this in the future
func TestIndexOverKDR(t *testing.T) {
_, err := aesCmKeyDerivation(labelSRTPAuthenticationTag, []byte{}, []byte{}, 1, 0)
assert.Error(t, err)
}
20 changes: 19 additions & 1 deletion protection_profile.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@ type ProtectionProfile uint16
// Supported protection profiles
const (
ProtectionProfileAes128CmHmacSha1_80 ProtectionProfile = 0x0001
ProtectionProfileAeadAes128Gcm ProtectionProfile = 0x0007
)

func (p ProtectionProfile) keyLen() (int, error) {
switch p {
case ProtectionProfileAes128CmHmacSha1_80:
fallthrough
case ProtectionProfileAeadAes128Gcm:
return 16, nil
default:
return 0, fmt.Errorf("no such ProtectionProfile %#v", p)
Expand All @@ -23,6 +26,8 @@ func (p ProtectionProfile) saltLen() (int, error) {
switch p {
case ProtectionProfileAes128CmHmacSha1_80:
return 14, nil
case ProtectionProfileAeadAes128Gcm:
return 12, nil
default:
return 0, fmt.Errorf("no such ProtectionProfile %#v", p)
}
Expand All @@ -31,7 +36,20 @@ func (p ProtectionProfile) saltLen() (int, error) {
func (p ProtectionProfile) authTagLen() (int, error) {
switch p {
case ProtectionProfileAes128CmHmacSha1_80:
return 10, nil
return (&srtpCipherAesCmHmacSha1{}).authTagLen(), nil
case ProtectionProfileAeadAes128Gcm:
return (&srtpCipherAeadAesGcm{}).authTagLen(), nil
default:
return 0, fmt.Errorf("no such ProtectionProfile %#v", p)
}
}

func (p ProtectionProfile) authKeyLen() (int, error) {
switch p {
case ProtectionProfileAes128CmHmacSha1_80:
return 20, nil
case ProtectionProfileAeadAes128Gcm:
return 0, nil
default:
return 0, fmt.Errorf("no such ProtectionProfile %#v", p)
}
Expand Down
15 changes: 8 additions & 7 deletions srtcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package srtp

import (
"encoding/binary"
"fmt"

"github.com/pion/rtcp"
)
Expand All @@ -10,14 +11,15 @@ const maxSRTCPIndex = 0x7FFFFFFF

func (c *Context) decryptRTCP(dst, encrypted []byte) ([]byte, error) {
out := allocateIfMismatch(dst, encrypted)

tailOffset := len(encrypted) - (c.cipher.authTagLen() + srtcpIndexSize)
if isEncrypted := encrypted[tailOffset] >> 7; isEncrypted == 0 {

if tailOffset < 0 {
return nil, fmt.Errorf("%d is too short to be a valid RTCP packet", len(encrypted))
} else if isEncrypted := encrypted[tailOffset] >> 7; isEncrypted == 0 {
return out, nil
}

srtcpIndexBuffer := encrypted[tailOffset : tailOffset+srtcpIndexSize]
index := binary.BigEndian.Uint32(srtcpIndexBuffer) &^ (1 << 31)
index := c.cipher.getRTCPIndex(encrypted)
ssrc := binary.BigEndian.Uint32(encrypted[4:])

s := c.getSRTCPSSRCState(ssrc)
Expand Down Expand Up @@ -49,8 +51,7 @@ func (c *Context) DecryptRTCP(dst, encrypted []byte, header *rtcp.Header) ([]byt
}

func (c *Context) encryptRTCP(dst, decrypted []byte) ([]byte, error) {
out := allocateIfMismatch(dst, decrypted)
ssrc := binary.BigEndian.Uint32(out[4:])
ssrc := binary.BigEndian.Uint32(decrypted[4:])
s := c.getSRTCPSSRCState(ssrc)

// We roll over early because MSB is used for marking as encrypted
Expand All @@ -59,7 +60,7 @@ func (c *Context) encryptRTCP(dst, decrypted []byte) ([]byte, error) {
s.srtcpIndex = 0
}

return c.cipher.encryptRTCP(out, decrypted, s.srtcpIndex, ssrc)
return c.cipher.encryptRTCP(dst, decrypted, s.srtcpIndex, ssrc)
}

// EncryptRTCP Encrypts a RTCP packet
Expand Down
8 changes: 2 additions & 6 deletions srtp.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,21 +47,17 @@ func (c *Context) EncryptRTP(dst []byte, plaintext []byte, header *rtp.Header) (
header = &rtp.Header{}
}

err := header.Unmarshal(plaintext)
if err != nil {
if err := header.Unmarshal(plaintext); err != nil {
return nil, err
}

return c.encryptRTP(dst, header, plaintext[header.PayloadOffset:])
}

// encryptRTP marshals and encrypts an RTP packet, writing to the dst buffer provided.
// If the dst buffer does not have the capacity to hold `len(plaintext) + 10` bytes, a new one will be allocated and returned.
// If the dst buffer does not have the capacity, a new one will be allocated and returned.
// Similar to above but faster because it can avoid unmarshaling the header and marshaling the payload.
func (c *Context) encryptRTP(dst []byte, header *rtp.Header, payload []byte) (ciphertext []byte, err error) {
// Grow the given buffer to fit the output.
dst = growBufferSize(dst, header.MarshalSize()+len(payload)+c.cipher.authTagLen())

s := c.getSRTPSSRCState(header.SSRC)
roc, updateROC := s.nextRolloverCount(header.SequenceNumber)
updateROC()
Expand Down
1 change: 1 addition & 0 deletions srtp_cipher.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import "github.com/pion/rtp"
// of the SRTP Specific ciphers
type srtpCipher interface {
authTagLen() int
getRTCPIndex([]byte) uint32

encryptRTP([]byte, *rtp.Header, []byte, uint32) ([]byte, error)
encryptRTCP([]byte, []byte, uint32, uint32) ([]byte, error)
Expand Down
Loading

0 comments on commit f871f43

Please sign in to comment.