diff --git a/context.go b/context.go index f9f024b..de13cad 100644 --- a/context.go +++ b/context.go @@ -75,16 +75,23 @@ func CreateContext(masterKey, masterSalt []byte, profile ProtectionProfile, opts return c, fmt.Errorf("SRTP Salt must be len %d, got %d", saltLen, masterSaltLen) } - sCipher, err := newSrtpCipherAesCmHmacSha1(masterKey, masterSalt) - if err != nil { - return nil, err - } - c = &Context{ - cipher: sCipher, srtpSSRCStates: map[uint32]*srtpSSRCState{}, srtcpSSRCStates: map[uint32]*srtcpSSRCState{}, } + + switch profile { + case ProtectionProfileAeadAes128Gcm: + c.cipher, err = newSrtpCipherAeadAesGcm(masterKey, masterSalt) + case ProtectionProfileAes128CmHmacSha1_80: + c.cipher, err = newSrtpCipherAesCmHmacSha1(masterKey, masterSalt) + default: + return nil, fmt.Errorf("no such SRTP Profile %#v", profile) + } + if err != nil { + return nil, err + } + for _, o := range append( []ContextOption{ // Default options SRTPNoReplayProtection(), diff --git a/key_derivation.go b/key_derivation.go index cdfe1ad..b78dca4 100644 --- a/key_derivation.go +++ b/key_derivation.go @@ -3,84 +3,42 @@ package srtp import ( "crypto/aes" "encoding/binary" + "errors" ) -// All of these key derivation functions are AES-CM specific -// in the future we have multiple implementations of each of these functions +func aesCmKeyDerivation(label byte, masterKey, masterSalt []byte, indexOverKdr int, outLen int) ([]byte, error) { + if indexOverKdr != 0 { + // 24-bit "index DIV kdr" must be xored to prf input. + return nil, errors.New("indexOverKdr > 0 is not supported yet") + } -func generateSessionKey(label byte, masterKey, masterSalt []byte) ([]byte, error) { // https://tools.ietf.org/html/rfc3711#appendix-B.3 // The input block for AES-CM is generated by exclusive-oring the master salt with the // concatenation of the encryption key label 0x00 with (index DIV kdr), // - index is 'rollover count' and DIV is 'divided by' - sessionKey := make([]byte, len(masterSalt)) - copy(sessionKey, masterSalt) - - labelAndIndexOverKdr := []byte{label, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} - for i, j := len(labelAndIndexOverKdr)-1, len(sessionKey)-1; i >= 0; i, j = i-1, j-1 { - sessionKey[j] = sessionKey[j] ^ labelAndIndexOverKdr[i] - } - // then padding on the right with two null octets (which implements the multiply-by-2^16 operation, see Section 4.3.3). - sessionKey = append(sessionKey, []byte{0x00, 0x00}...) + nMasterKey := len(masterKey) + nMasterSalt := len(masterSalt) - //The resulting value is then AES-CM- encrypted using the master key to get the cipher key. - block, err := aes.NewCipher(masterKey) - if err != nil { - return nil, err - } - - block.Encrypt(sessionKey, sessionKey) - return sessionKey, nil -} - -func generateSessionSalt(label byte, masterKey, masterSalt []byte) ([]byte, error) { - // https://tools.ietf.org/html/rfc3711#appendix-B.3 - // The input block for AES-CM is generated by exclusive-oring the master salt with - // the concatenation of the encryption salt label - sessionSalt := make([]byte, len(masterSalt)) - copy(sessionSalt, masterSalt) + prfIn := make([]byte, nMasterKey) + copy(prfIn[:nMasterSalt], masterSalt) - labelAndIndexOverKdr := []byte{label, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} - for i, j := len(labelAndIndexOverKdr)-1, len(sessionSalt)-1; i >= 0; i, j = i-1, j-1 { - sessionSalt[j] = sessionSalt[j] ^ labelAndIndexOverKdr[i] - } + prfIn[7] ^= label - // That value is padded and encrypted as above. - sessionSalt = append(sessionSalt, []byte{0x00, 0x00}...) + //The resulting value is then AES encrypted using the master key to get the cipher key. block, err := aes.NewCipher(masterKey) if err != nil { return nil, err } - block.Encrypt(sessionSalt, sessionSalt) - return sessionSalt[0:len(masterSalt)], nil -} - -func generateSessionAuthTag(label byte, masterKey, masterSalt []byte) ([]byte, error) { - // https://tools.ietf.org/html/rfc3711#appendix-B.3 - // We now show how the auth key is generated. The input block for AES- - // CM is generated as above, but using the authentication key label. - sessionAuthTag := make([]byte, len(masterSalt)) - copy(sessionAuthTag, masterSalt) - - labelAndIndexOverKdr := []byte{label, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} - for i, j := len(labelAndIndexOverKdr)-1, len(sessionAuthTag)-1; i >= 0; i, j = i-1, j-1 { - sessionAuthTag[j] = sessionAuthTag[j] ^ labelAndIndexOverKdr[i] + out := make([]byte, ((outLen+nMasterKey)/nMasterKey)*nMasterKey) + var i uint16 + for n := 0; n < outLen; n += nMasterKey { + binary.BigEndian.PutUint16(prfIn[nMasterKey-2:], i) + block.Encrypt(out[n:n+nMasterKey], prfIn) + i++ } - - // That value is padded and encrypted as above. - // - We need to do multiple runs at key size (20) is larger then source - firstRun := append(sessionAuthTag, []byte{0x00, 0x00}...) - secondRun := append(sessionAuthTag, []byte{0x00, 0x01}...) - block, err := aes.NewCipher(masterKey) - if err != nil { - return nil, err - } - - block.Encrypt(firstRun, firstRun) - block.Encrypt(secondRun, secondRun) - return append(firstRun, secondRun[:4]...), nil + return out[:outLen], nil } // Generate IV https://tools.ietf.org/html/rfc3711#section-4.1.1 diff --git a/key_derivation_test.go b/key_derivation_test.go index 8acdbbe..42bf568 100644 --- a/key_derivation_test.go +++ b/key_derivation_test.go @@ -3,6 +3,8 @@ package srtp import ( "bytes" "testing" + + "github.com/stretchr/testify/assert" ) func TestValidSessionKeys(t *testing.T) { @@ -13,24 +15,34 @@ func TestValidSessionKeys(t *testing.T) { expectedSessionSalt := []byte{0x30, 0xCB, 0xBC, 0x08, 0x86, 0x3D, 0x8C, 0x85, 0xD4, 0x9D, 0xB3, 0x4A, 0x9A, 0xE1} expectedSessionAuthTag := []byte{0xCE, 0xBE, 0x32, 0x1F, 0x6F, 0xF7, 0x71, 0x6B, 0x6F, 0xD4, 0xAB, 0x49, 0xAF, 0x25, 0x6A, 0x15, 0x6D, 0x38, 0xBA, 0xA4} - sessionKey, err := generateSessionKey(labelSRTPEncryption, masterKey, masterSalt) + sessionKey, err := aesCmKeyDerivation(labelSRTPEncryption, masterKey, masterSalt, 0, len(masterKey)) if err != nil { t.Errorf("generateSessionKey failed: %v", err) } else if !bytes.Equal(sessionKey, expectedSessionKey) { t.Errorf("Session Key % 02x does not match expected % 02x", sessionKey, expectedSessionKey) } - sessionSalt, err := generateSessionSalt(labelSRTPSalt, masterKey, masterSalt) + sessionSalt, err := aesCmKeyDerivation(labelSRTPSalt, masterKey, masterSalt, 0, len(masterSalt)) if err != nil { t.Errorf("generateSessionSalt failed: %v", err) } else if !bytes.Equal(sessionSalt, expectedSessionSalt) { t.Errorf("Session Salt % 02x does not match expected % 02x", sessionSalt, expectedSessionSalt) } - sessionAuthTag, err := generateSessionAuthTag(labelSRTPAuthenticationTag, masterKey, masterSalt) + authKeyLen, err := ProtectionProfileAes128CmHmacSha1_80.authKeyLen() + assert.NoError(t, err) + + sessionAuthTag, err := aesCmKeyDerivation(labelSRTPAuthenticationTag, masterKey, masterSalt, 0, authKeyLen) if err != nil { t.Errorf("generateSessionAuthTag failed: %v", err) } else if !bytes.Equal(sessionAuthTag, expectedSessionAuthTag) { t.Errorf("Session Auth Tag % 02x does not match expected % 02x", sessionAuthTag, expectedSessionAuthTag) } } + +// This test asserts that calling aesCmKeyDerivation with a non-zero indexOverKdr fails +// Currently this isn't supported, but the API makes sure we can add this in the future +func TestIndexOverKDR(t *testing.T) { + _, err := aesCmKeyDerivation(labelSRTPAuthenticationTag, []byte{}, []byte{}, 1, 0) + assert.Error(t, err) +} diff --git a/protection_profile.go b/protection_profile.go index 0fb058c..9a54bf6 100644 --- a/protection_profile.go +++ b/protection_profile.go @@ -8,11 +8,14 @@ type ProtectionProfile uint16 // Supported protection profiles const ( ProtectionProfileAes128CmHmacSha1_80 ProtectionProfile = 0x0001 + ProtectionProfileAeadAes128Gcm ProtectionProfile = 0x0007 ) func (p ProtectionProfile) keyLen() (int, error) { switch p { case ProtectionProfileAes128CmHmacSha1_80: + fallthrough + case ProtectionProfileAeadAes128Gcm: return 16, nil default: return 0, fmt.Errorf("no such ProtectionProfile %#v", p) @@ -23,6 +26,8 @@ func (p ProtectionProfile) saltLen() (int, error) { switch p { case ProtectionProfileAes128CmHmacSha1_80: return 14, nil + case ProtectionProfileAeadAes128Gcm: + return 12, nil default: return 0, fmt.Errorf("no such ProtectionProfile %#v", p) } @@ -31,7 +36,20 @@ func (p ProtectionProfile) saltLen() (int, error) { func (p ProtectionProfile) authTagLen() (int, error) { switch p { case ProtectionProfileAes128CmHmacSha1_80: - return 10, nil + return (&srtpCipherAesCmHmacSha1{}).authTagLen(), nil + case ProtectionProfileAeadAes128Gcm: + return (&srtpCipherAeadAesGcm{}).authTagLen(), nil + default: + return 0, fmt.Errorf("no such ProtectionProfile %#v", p) + } +} + +func (p ProtectionProfile) authKeyLen() (int, error) { + switch p { + case ProtectionProfileAes128CmHmacSha1_80: + return 20, nil + case ProtectionProfileAeadAes128Gcm: + return 0, nil default: return 0, fmt.Errorf("no such ProtectionProfile %#v", p) } diff --git a/srtcp.go b/srtcp.go index 8d94b52..9e377ca 100644 --- a/srtcp.go +++ b/srtcp.go @@ -2,6 +2,7 @@ package srtp import ( "encoding/binary" + "fmt" "github.com/pion/rtcp" ) @@ -10,14 +11,15 @@ const maxSRTCPIndex = 0x7FFFFFFF func (c *Context) decryptRTCP(dst, encrypted []byte) ([]byte, error) { out := allocateIfMismatch(dst, encrypted) - tailOffset := len(encrypted) - (c.cipher.authTagLen() + srtcpIndexSize) - if isEncrypted := encrypted[tailOffset] >> 7; isEncrypted == 0 { + + if tailOffset < 0 { + return nil, fmt.Errorf("%d is too short to be a valid RTCP packet", len(encrypted)) + } else if isEncrypted := encrypted[tailOffset] >> 7; isEncrypted == 0 { return out, nil } - srtcpIndexBuffer := encrypted[tailOffset : tailOffset+srtcpIndexSize] - index := binary.BigEndian.Uint32(srtcpIndexBuffer) &^ (1 << 31) + index := c.cipher.getRTCPIndex(encrypted) ssrc := binary.BigEndian.Uint32(encrypted[4:]) s := c.getSRTCPSSRCState(ssrc) @@ -49,8 +51,7 @@ func (c *Context) DecryptRTCP(dst, encrypted []byte, header *rtcp.Header) ([]byt } func (c *Context) encryptRTCP(dst, decrypted []byte) ([]byte, error) { - out := allocateIfMismatch(dst, decrypted) - ssrc := binary.BigEndian.Uint32(out[4:]) + ssrc := binary.BigEndian.Uint32(decrypted[4:]) s := c.getSRTCPSSRCState(ssrc) // We roll over early because MSB is used for marking as encrypted @@ -59,7 +60,7 @@ func (c *Context) encryptRTCP(dst, decrypted []byte) ([]byte, error) { s.srtcpIndex = 0 } - return c.cipher.encryptRTCP(out, decrypted, s.srtcpIndex, ssrc) + return c.cipher.encryptRTCP(dst, decrypted, s.srtcpIndex, ssrc) } // EncryptRTCP Encrypts a RTCP packet diff --git a/srtp.go b/srtp.go index 47ba5e1..f9f2b71 100644 --- a/srtp.go +++ b/srtp.go @@ -47,8 +47,7 @@ func (c *Context) EncryptRTP(dst []byte, plaintext []byte, header *rtp.Header) ( header = &rtp.Header{} } - err := header.Unmarshal(plaintext) - if err != nil { + if err := header.Unmarshal(plaintext); err != nil { return nil, err } @@ -56,12 +55,9 @@ func (c *Context) EncryptRTP(dst []byte, plaintext []byte, header *rtp.Header) ( } // encryptRTP marshals and encrypts an RTP packet, writing to the dst buffer provided. -// If the dst buffer does not have the capacity to hold `len(plaintext) + 10` bytes, a new one will be allocated and returned. +// If the dst buffer does not have the capacity, a new one will be allocated and returned. // Similar to above but faster because it can avoid unmarshaling the header and marshaling the payload. func (c *Context) encryptRTP(dst []byte, header *rtp.Header, payload []byte) (ciphertext []byte, err error) { - // Grow the given buffer to fit the output. - dst = growBufferSize(dst, header.MarshalSize()+len(payload)+c.cipher.authTagLen()) - s := c.getSRTPSSRCState(header.SSRC) roc, updateROC := s.nextRolloverCount(header.SequenceNumber) updateROC() diff --git a/srtp_cipher.go b/srtp_cipher.go index dbc91fb..6c1fda0 100644 --- a/srtp_cipher.go +++ b/srtp_cipher.go @@ -6,6 +6,7 @@ import "github.com/pion/rtp" // of the SRTP Specific ciphers type srtpCipher interface { authTagLen() int + getRTCPIndex([]byte) uint32 encryptRTP([]byte, *rtp.Header, []byte, uint32) ([]byte, error) encryptRTCP([]byte, []byte, uint32, uint32) ([]byte, error) diff --git a/srtp_cipher_aead_aes_gcm.go b/srtp_cipher_aead_aes_gcm.go new file mode 100644 index 0000000..58e1b8f --- /dev/null +++ b/srtp_cipher_aead_aes_gcm.go @@ -0,0 +1,172 @@ +package srtp + +import ( + "crypto/aes" + "crypto/cipher" + "encoding/binary" + + "github.com/pion/rtp" +) + +const ( + rtcpEncryptionFlag = 0x80 +) + +type srtpCipherAeadAesGcm struct { + srtpCipher, srtcpCipher cipher.AEAD + + srtpSessionSalt, srtcpSessionSalt []byte +} + +func newSrtpCipherAeadAesGcm(masterKey, masterSalt []byte) (*srtpCipherAeadAesGcm, error) { + s := &srtpCipherAeadAesGcm{} + + srtpSessionKey, err := aesCmKeyDerivation(labelSRTPEncryption, masterKey, masterSalt, 0, len(masterKey)) + if err != nil { + return nil, err + } + + srtpBlock, err := aes.NewCipher(srtpSessionKey) + if err != nil { + return nil, err + } + + s.srtpCipher, err = cipher.NewGCM(srtpBlock) + if err != nil { + return nil, err + } + + srtcpSessionKey, err := aesCmKeyDerivation(labelSRTCPEncryption, masterKey, masterSalt, 0, len(masterKey)) + if err != nil { + return nil, err + } + + srtcpBlock, err := aes.NewCipher(srtcpSessionKey) + if err != nil { + return nil, err + } + + s.srtcpCipher, err = cipher.NewGCM(srtcpBlock) + if err != nil { + return nil, err + } + + if s.srtpSessionSalt, err = aesCmKeyDerivation(labelSRTPSalt, masterKey, masterSalt, 0, len(masterSalt)); err != nil { + return nil, err + } else if s.srtcpSessionSalt, err = aesCmKeyDerivation(labelSRTCPSalt, masterKey, masterSalt, 0, len(masterSalt)); err != nil { + return nil, err + } + + return s, nil +} + +func (s *srtpCipherAeadAesGcm) authTagLen() int { + return 16 +} + +func (s *srtpCipherAeadAesGcm) encryptRTP(dst []byte, header *rtp.Header, payload []byte, roc uint32) (ciphertext []byte, err error) { + hdr, err := header.Marshal() + if err != nil { + return nil, err + } + + iv := s.rtpInitializationVector(header, roc) + out := s.srtpCipher.Seal(nil, iv, payload, hdr) + return append(hdr, out...), nil +} + +func (s *srtpCipherAeadAesGcm) decryptRTP(dst, ciphertext []byte, header *rtp.Header, roc uint32) ([]byte, error) { + iv := s.rtpInitializationVector(header, roc) + + out, err := s.srtpCipher.Open(nil, iv, ciphertext[header.PayloadOffset:], ciphertext[:header.PayloadOffset]) + if err != nil { + return nil, err + } + + out = append(make([]byte, header.PayloadOffset), out...) + copy(out, ciphertext[:header.PayloadOffset]) + + return out, nil +} + +func (s *srtpCipherAeadAesGcm) encryptRTCP(dst, decrypted []byte, srtcpIndex uint32, ssrc uint32) ([]byte, error) { + iv := s.rtcpInitializationVector(srtcpIndex, ssrc) + aad := s.rtcpAdditionalAuthenticatedData(decrypted, srtcpIndex) + + out := s.srtcpCipher.Seal(nil, iv, decrypted[8:], aad) + + out = append(make([]byte, 8), out...) + copy(out, decrypted[:8]) + out = append(out, aad[8:]...) + + return out, nil +} + +func (s *srtpCipherAeadAesGcm) decryptRTCP(out, encrypted []byte, srtcpIndex, ssrc uint32) ([]byte, error) { + iv := s.rtcpInitializationVector(srtcpIndex, ssrc) + aad := s.rtcpAdditionalAuthenticatedData(encrypted, srtcpIndex) + + decrypted, err := s.srtcpCipher.Open(nil, iv, encrypted[8:len(encrypted)-srtcpIndexSize], aad) + if err != nil { + return nil, err + } + + decrypted = append(encrypted[:8], decrypted...) + return decrypted, nil +} + +// The 12-octet IV used by AES-GCM SRTP is formed by first concatenating +// 2 octets of zeroes, the 4-octet SSRC, the 4-octet rollover counter +// (ROC), and the 2-octet sequence number (SEQ). The resulting 12-octet +// value is then XORed to the 12-octet salt to form the 12-octet IV. +// +// https://tools.ietf.org/html/rfc7714#section-8.1 +func (s *srtpCipherAeadAesGcm) rtpInitializationVector(header *rtp.Header, roc uint32) []byte { + iv := make([]byte, 12) + binary.BigEndian.PutUint32(iv[2:], header.SSRC) + binary.BigEndian.PutUint32(iv[6:], roc) + binary.BigEndian.PutUint16(iv[10:], header.SequenceNumber) + + for i := range iv { + iv[i] ^= s.srtpSessionSalt[i] + } + return iv +} + +// The 12-octet IV used by AES-GCM SRTCP is formed by first +// concatenating 2 octets of zeroes, the 4-octet SSRC identifier, +// 2 octets of zeroes, a single "0" bit, and the 31-bit SRTCP index. +// The resulting 12-octet value is then XORed to the 12-octet salt to +// form the 12-octet IV. +// +// https://tools.ietf.org/html/rfc7714#section-9.1 +func (s *srtpCipherAeadAesGcm) rtcpInitializationVector(srtcpIndex uint32, ssrc uint32) []byte { + iv := make([]byte, 12) + + binary.BigEndian.PutUint32(iv[2:], ssrc) + binary.BigEndian.PutUint32(iv[8:], srtcpIndex) + + for i := range iv { + iv[i] ^= s.srtcpSessionSalt[i] + } + return iv +} + +// In an SRTCP packet, a 1-bit Encryption flag is prepended to the +// 31-bit SRTCP index to form a 32-bit value we shall call the +// "ESRTCP word" +// +// https://tools.ietf.org/html/rfc7714#section-17 +func (s *srtpCipherAeadAesGcm) rtcpAdditionalAuthenticatedData(rtcpPacket []byte, srtcpIndex uint32) []byte { + aad := make([]byte, 12) + + copy(aad, rtcpPacket[:8]) + binary.BigEndian.PutUint32(aad[8:], srtcpIndex) + aad[8] |= rtcpEncryptionFlag + + return aad +} + +func (s *srtpCipherAeadAesGcm) getRTCPIndex(in []byte) uint32 { + return binary.BigEndian.Uint32(in[len(in)-4:]) &^ (rtcpEncryptionFlag << 24) +} diff --git a/srtp_cipher_aead_aes_gcm_test.go b/srtp_cipher_aead_aes_gcm_test.go new file mode 100644 index 0000000..6c1c603 --- /dev/null +++ b/srtp_cipher_aead_aes_gcm_test.go @@ -0,0 +1,84 @@ +package srtp + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSrtpCipherAedAesGcm(t *testing.T) { + decryptedRTPPacket := []byte{ + 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, + 0xca, 0xfe, 0xba, 0xbe, 0xab, 0xab, 0xab, 0xab, + 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, + 0xab, 0xab, 0xab, 0xab, + } + encryptedRTPPacket := []byte{ + 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, + 0xca, 0xfe, 0xba, 0xbe, 0xc5, 0x00, 0x2e, 0xde, + 0x04, 0xcf, 0xdd, 0x2e, 0xb9, 0x11, 0x59, 0xe0, + 0x88, 0x0a, 0xa0, 0x6e, 0xd2, 0x97, 0x68, 0x26, + 0xf7, 0x96, 0xb2, 0x01, 0xdf, 0x31, 0x31, 0xa1, + 0x27, 0xe8, 0xa3, 0x92, + } + decryptedRtcpPacket := []byte{ + 0x81, 0xc8, 0x00, 0x0b, 0xca, 0xfe, 0xba, 0xbe, + 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, + 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, + } + encryptedRtcpPacket := []byte{ + 0x81, 0xc8, 0x00, 0x0b, 0xca, 0xfe, 0xba, 0xbe, + 0xc9, 0x8b, 0x8b, 0x5d, 0xf0, 0x39, 0x2a, 0x55, + 0x85, 0x2b, 0x6c, 0x21, 0xac, 0x8e, 0x70, 0x25, + 0xc5, 0x2c, 0x6f, 0xbe, 0xa2, 0xb3, 0xb4, 0x46, + 0xea, 0x31, 0x12, 0x3b, 0xa8, 0x8c, 0xe6, 0x1e, + 0x80, 0x00, 0x00, 0x01, + } + + masterKey := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f} + masterSalt := []byte{0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xab} + + t.Run("Encrypt RTP", func(t *testing.T) { + ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAeadAes128Gcm) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualEncrypted, err := ctx.EncryptRTP(nil, decryptedRTPPacket, nil) + assert.NoError(t, err) + assert.Equal(t, encryptedRTPPacket, actualEncrypted) + }) + }) + + t.Run("Decrypt RTP", func(t *testing.T) { + ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAeadAes128Gcm) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualDecrypted, err := ctx.DecryptRTP(nil, encryptedRTPPacket, nil) + assert.NoError(t, err) + assert.Equal(t, decryptedRTPPacket, actualDecrypted) + }) + }) + + t.Run("Encrypt RTCP", func(t *testing.T) { + ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAeadAes128Gcm) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualEncrypted, err := ctx.EncryptRTCP(nil, decryptedRtcpPacket, nil) + assert.NoError(t, err) + assert.Equal(t, encryptedRtcpPacket, actualEncrypted) + }) + }) + + t.Run("Decrypt RTCP", func(t *testing.T) { + ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAeadAes128Gcm) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualDecrypted, err := ctx.DecryptRTCP(nil, encryptedRtcpPacket, nil) + assert.NoError(t, err) + assert.Equal(t, decryptedRtcpPacket, actualDecrypted) + }) + }) +} diff --git a/srtp_cipher_aes_cm_hmac_sha1.go b/srtp_cipher_aes_cm_hmac_sha1.go index d6e8312..a6eedf8 100644 --- a/srtp_cipher_aes_cm_hmac_sha1.go +++ b/srtp_cipher_aes_cm_hmac_sha1.go @@ -25,32 +25,37 @@ type srtpCipherAesCmHmacSha1 struct { func newSrtpCipherAesCmHmacSha1(masterKey, masterSalt []byte) (*srtpCipherAesCmHmacSha1, error) { s := &srtpCipherAesCmHmacSha1{} - srtpSessionKey, err := generateSessionKey(labelSRTPEncryption, masterKey, masterSalt) + srtpSessionKey, err := aesCmKeyDerivation(labelSRTPEncryption, masterKey, masterSalt, 0, len(masterKey)) if err != nil { return nil, err } else if s.srtpBlock, err = aes.NewCipher(srtpSessionKey); err != nil { return nil, err } - srtcpSessionKey, err := generateSessionKey(labelSRTCPEncryption, masterKey, masterSalt) + srtcpSessionKey, err := aesCmKeyDerivation(labelSRTCPEncryption, masterKey, masterSalt, 0, len(masterKey)) if err != nil { return nil, err } else if s.srtcpBlock, err = aes.NewCipher(srtcpSessionKey); err != nil { return nil, err } - if s.srtpSessionSalt, err = generateSessionSalt(labelSRTPSalt, masterKey, masterSalt); err != nil { + if s.srtpSessionSalt, err = aesCmKeyDerivation(labelSRTPSalt, masterKey, masterSalt, 0, len(masterSalt)); err != nil { return nil, err - } else if s.srtcpSessionSalt, err = generateSessionSalt(labelSRTCPSalt, masterKey, masterSalt); err != nil { + } else if s.srtcpSessionSalt, err = aesCmKeyDerivation(labelSRTCPSalt, masterKey, masterSalt, 0, len(masterSalt)); err != nil { return nil, err } - srtpSessionAuthTag, err := generateSessionAuthTag(labelSRTPAuthenticationTag, masterKey, masterSalt) + authKeyLen, err := ProtectionProfileAes128CmHmacSha1_80.authKeyLen() if err != nil { return nil, err } - srtcpSessionAuthTag, err := generateSessionAuthTag(labelSRTCPAuthenticationTag, masterKey, masterSalt) + srtpSessionAuthTag, err := aesCmKeyDerivation(labelSRTPAuthenticationTag, masterKey, masterSalt, 0, authKeyLen) + if err != nil { + return nil, err + } + + srtcpSessionAuthTag, err := aesCmKeyDerivation(labelSRTCPAuthenticationTag, masterKey, masterSalt, 0, authKeyLen) if err != nil { return nil, err } @@ -65,6 +70,9 @@ func (s *srtpCipherAesCmHmacSha1) authTagLen() int { } func (s *srtpCipherAesCmHmacSha1) encryptRTP(dst []byte, header *rtp.Header, payload []byte, roc uint32) (ciphertext []byte, err error) { + // Grow the given buffer to fit the output. + dst = growBufferSize(dst, header.MarshalSize()+len(payload)+s.authTagLen()) + // Copy the header unencrypted. n, err := header.MarshalTo(dst) if err != nil { @@ -117,6 +125,8 @@ func (s *srtpCipherAesCmHmacSha1) decryptRTP(dst, ciphertext []byte, header *rtp } func (s *srtpCipherAesCmHmacSha1) encryptRTCP(dst, decrypted []byte, srtcpIndex uint32, ssrc uint32) ([]byte, error) { + dst = allocateIfMismatch(dst, decrypted) + // Encrypt everything after header stream := cipher.NewCTR(s.srtcpBlock, generateCounter(uint16(srtcpIndex&0xffff), srtcpIndex>>16, ssrc, s.srtcpSessionSalt)) stream.XORKeyStream(dst[8:], dst[8:]) @@ -207,3 +217,9 @@ func (s *srtpCipherAesCmHmacSha1) generateSrtcpAuthTag(buf []byte) ([]byte, erro return s.srtcpSessionAuth.Sum(nil)[0:s.authTagLen()], nil } + +func (s *srtpCipherAesCmHmacSha1) getRTCPIndex(in []byte) uint32 { + tailOffset := len(in) - (s.authTagLen() + srtcpIndexSize) + srtcpIndexBuffer := in[tailOffset : tailOffset+srtcpIndexSize] + return binary.BigEndian.Uint32(srtcpIndexBuffer) &^ (1 << 31) +} diff --git a/srtp_test.go b/srtp_test.go index 027f2a2..2aed1fb 100644 --- a/srtp_test.go +++ b/srtp_test.go @@ -40,7 +40,7 @@ func TestValidPacketCounter(t *testing.T) { masterKey := []byte{0x0d, 0xcd, 0x21, 0x3e, 0x4c, 0xbc, 0xf2, 0x8f, 0x01, 0x7f, 0x69, 0x94, 0x40, 0x1e, 0x28, 0x89} masterSalt := []byte{0x62, 0x77, 0x60, 0x38, 0xc0, 0x6d, 0xc9, 0x41, 0x9f, 0x6d, 0xd9, 0x43, 0x3e, 0x7c} - srtpSessionSalt, err := generateSessionSalt(labelSRTPSalt, masterKey, masterSalt) + srtpSessionSalt, err := aesCmKeyDerivation(labelSRTPSalt, masterKey, masterSalt, 0, len(masterSalt)) assert.NoError(t, err) s := &srtpSSRCState{ssrc: 4160032510}