Skip to content

Commit

Permalink
Merge pull request #651 from sselph/master
Browse files Browse the repository at this point in the history
A few iterations on the end-to-end code
  • Loading branch information
Emmanuel T Odeke committed May 23, 2016
2 parents 0086aa9 + 00e104a commit 87197b6
Show file tree
Hide file tree
Showing 4 changed files with 218 additions and 47 deletions.
49 changes: 38 additions & 11 deletions src/dcrypto/dcrypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"bytes"
"encoding/binary"
"fmt"
"hash"
"io"

"github.com/odeke-em/drive/src/dcrypto/v1"
Expand All @@ -28,11 +29,14 @@ import (
// Version is the version of the en/decryption library used.
type Version uint32

// Decrypter is a function that creates a decrypter.
type Decrypter func(io.Reader, []byte) (io.ReadCloser, error)
// decrypter is a function that creates a decrypter.
type decrypter func(io.Reader, []byte) (io.ReadCloser, error)

// Encrypter is a function that creates a encrypter.
type Encrypter func(io.Reader, []byte) (io.Reader, error)
// encrypter is a function that creates a encrypter.
type encrypter func(io.Reader, []byte) (io.Reader, error)

// hasher is a function that returns the hash of a plaintext as if it were encrypted.
type hasher func(io.Reader, io.Reader, []byte, hash.Hash) ([]byte, error)

// These are the different versions of the en/decryption library.
const (
Expand All @@ -42,20 +46,29 @@ const (
// PreferedVersion is the prefered version of encryption.
const PreferedVersion = V1

var encrypters map[Version]Encrypter
var decrypters map[Version]Decrypter
var encrypters map[Version]encrypter
var decrypters map[Version]decrypter
var hashers map[Version]hasher

// MaxHeaderSize is the maximum header size of all versions.
// This many bytes at the beginning of a file should be enough to compute
// a hash of a local file.
var MaxHeaderSize = v1.HeaderSize + 4

func init() {
decrypters = map[Version]Decrypter{
decrypters = map[Version]decrypter{
V1: v1.NewDecryptReader,
}

encrypters = map[Version]Encrypter{
encrypters = map[Version]encrypter{
V1: v1.NewEncryptReader,
}
hashers = map[Version]hasher{
V1: v1.Hash,
}
}

// NewEncrypter returns an Encrypter using the PreferedVersion.
// NewEncrypter returns an encrypting reader using the PreferedVersion.
func NewEncrypter(r io.Reader, password []byte) (io.Reader, error) {
v, err := writeVersion(PreferedVersion)
if err != nil {
Expand All @@ -69,10 +82,10 @@ func NewEncrypter(r io.Reader, password []byte) (io.Reader, error) {
if err != nil {
return nil, err
}
return io.MultiReader(bytes.NewBuffer(v), encReader), nil
return io.MultiReader(bytes.NewReader(v), encReader), nil
}

// NewDecrypter returns a Decrypter based on the version used to encrypt.
// NewDecrypter returns a decrypting reader based on the version used to encrypt.
func NewDecrypter(r io.Reader, password []byte) (io.ReadCloser, error) {
version, err := readVersion(r)
if err != nil {
Expand All @@ -85,6 +98,20 @@ func NewDecrypter(r io.Reader, password []byte) (io.ReadCloser, error) {
return decrypterFn(r, password)
}

// Hash will hash of plaintext based on the header of the encrypted file and returns the hash Sum.
func Hash(r io.Reader, header io.Reader, password []byte, hashFunc func() hash.Hash) ([]byte, error) {
h := hashFunc()
version, err := readVersion(io.TeeReader(header, h))
if err != nil {
return nil, err
}
hasherFn, ok := hashers[version]
if !ok {
return nil, fmt.Errorf("unknown hasher for version(%d)", version)
}
return hasherFn(r, header, password, h)
}

// writeVersion converts a Version to a []byte.
func writeVersion(i Version) ([]byte, error) {
buf := new(bytes.Buffer)
Expand Down
40 changes: 38 additions & 2 deletions src/dcrypto/dcrypto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ package dcrypto_test
import (
"bytes"
"crypto/rand"
"crypto/sha256"
"io"
"io/ioutil"
"testing"

Expand Down Expand Up @@ -54,7 +56,7 @@ func TestRoundTrip(t *testing.T) {
t.Errorf("randBytes(%d) => %q; want nil", size, err)
continue
}
encReader, err := dcrypto.NewEncrypter(bytes.NewBuffer(b), password)
encReader, err := dcrypto.NewEncrypter(bytes.NewReader(b), password)
if err != nil {
t.Errorf("NewEncrypter() => %q; want nil", err)
continue
Expand All @@ -64,7 +66,7 @@ func TestRoundTrip(t *testing.T) {
t.Errorf("ioutil.ReadAll(*Encrypter) => %q; want nil", err)
continue
}
decReader, err := dcrypto.NewDecrypter(bytes.NewBuffer(cipher), password)
decReader, err := dcrypto.NewDecrypter(bytes.NewReader(cipher), password)
if err != nil {
t.Errorf("NewDecrypter() => %q; want nil", err)
continue
Expand All @@ -81,3 +83,37 @@ func TestRoundTrip(t *testing.T) {
}
}
}

func TestHash(t *testing.T) {
password := []byte("test")
sizes := []int{0, 24, 1337, 66560}
for _, size := range sizes {
h := sha256.New()
t.Logf("Testing file of size: %db, with password: %s", size, password)
b, err := randBytes(size)
if err != nil {
t.Errorf("randBytes(%d) => %q; want nil", size, err)
continue
}
encReader, err := dcrypto.NewEncrypter(bytes.NewReader(b), password)
if err != nil {
t.Errorf("NewEncryper() => %q; want nil", err)
continue
}
cipher, err := ioutil.ReadAll(io.TeeReader(encReader, h))
if err != nil {
t.Errorf("ioutil.ReadAll(*EncryptReader) => %q; want nil", err)
continue
}
want := h.Sum(nil)
got, err := dcrypto.Hash(bytes.NewReader(b), bytes.NewReader(cipher[0:dcrypto.MaxHeaderSize]), password, sha256.New)
if err != nil {
t.Errorf("Hash() => err = %q; want nil", err)
continue
}
if !bytes.Equal(got, want) {
t.Errorf("Hash() => %v; want %v", got, want)
}
}

}
128 changes: 100 additions & 28 deletions src/dcrypto/v1/v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"crypto/hmac"
"crypto/rand"
"crypto/sha512"
"encoding/binary"
"errors"
"hash"
"io"
Expand Down Expand Up @@ -58,7 +59,7 @@ const (
// The number of iterations to use in for key generation
// See N value in https://godoc.org/golang.org/x/crypto/scrypt#Key
// Must be a power of 2.
scryptIterations = 262144 // 2^18
scryptIterations int32 = 262144 // 2^18
)

const _16KB = 16 * 1024
Expand All @@ -69,6 +70,13 @@ var (

// The amount of key material we need.
keySize = hmacKeySize + aesKeySize

// The size of the Header.
HeaderSize = 4 + saltSize + blockSize

// The overhead added to the file by using this library.
// OverHead + len(plaintext) == len(ciphertext)
OverHead = HeaderSize + hmacSize
)

var DecryptErr = errors.New("message corrupt or incorrect password")
Expand All @@ -91,27 +99,51 @@ func keys(pass, salt []byte, iterations int) (aesKey, hmacKey []byte, err error)
return aesKey, hmacKey, nil
}

// Make sure we implement io.ReadWriter.
var _ io.ReadWriter = &hashReadWriter{}

// hashReadWriter hashes on write and on read finalizes the hash and returns it.
// Writes after a Read will return an error.
type hashReadWriter struct {
hash hash.Hash
done bool
sum io.Reader
}

// Write implements io.Writer
func (h *hashReadWriter) Write(p []byte) (int, error) {
if h.done {
return 0, errors.New("writing to hashReadWriter after read is not allowed")
}
return h.hash.Write(p)
}

// Read implements io.Reader.
func (h *hashReadWriter) Read(p []byte) (int, error) {
if !h.done {
h.done = true
h.sum = bytes.NewBuffer(h.hash.Sum(nil))
h.sum = bytes.NewReader(h.hash.Sum(nil))
}
return h.sum.Read(p)
}

// encInt32 will encode a int32 in to a byte slice.
func encInt32(i int32) ([]byte, error) {
buf := new(bytes.Buffer)
if err := binary.Write(buf, binary.LittleEndian, i); err != nil {
return nil, err
}
return buf.Bytes(), nil
}

// decInt32 will read an int32 from a reader and return the byte slice and the int32.
func decInt32(r io.Reader) (b []byte, i int32, err error) {
buf := new(bytes.Buffer)
tr := io.TeeReader(r, buf)
err = binary.Read(tr, binary.LittleEndian, &i)
return buf.Bytes(), i, err
}

// NewEncryptReader returns an io.Reader wrapping the provided io.Reader.
// It uses a user provided password and a random salt to derive keys.
// If the key is provided interactively, it should be verified since there
Expand All @@ -127,26 +159,63 @@ func NewEncryptReader(r io.Reader, pass []byte) (io.Reader, error) {
// newEncryptReader returns a encryptReader wrapping an io.Reader.
// It uses a user provided password and the provided salt iterated the
// provided number of times to derive keys.
func newEncryptReader(r io.Reader, pass, salt []byte, iterations int) (io.Reader, error) {
aesKey, hmacKey, err := keys(pass, salt, iterations)
func newEncryptReader(r io.Reader, pass, salt []byte, iterations int32) (io.Reader, error) {
itersAsBytes, err := encInt32(iterations)
if err != nil {
return nil, err
}
b, err := aes.NewCipher(aesKey)
aesKey, hmacKey, err := keys(pass, salt, int(iterations))
if err != nil {
return nil, err
}
h := hmac.New(hashFunc, hmacKey)
iv, err := randBytes(blockSize)
if err != nil {
return nil, err
}
var header []byte
header = append(header, itersAsBytes...)
header = append(header, salt...)
header = append(header, iv...)
return encrypter(r, aesKey, hmacKey, iv, header)
}

// encrypter returns the encrypted reader pased on the keys and IV provided.
func encrypter(r io.Reader, aesKey, hmacKey, iv, header []byte) (io.Reader, error) {
b, err := aes.NewCipher(aesKey)
if err != nil {
return nil, err
}
h := hmac.New(hashFunc, hmacKey)
hr := &hashReadWriter{hash: h}
sr := &cipher.StreamReader{R: r, S: cipher.NewCTR(b, iv)}
var header []byte
return io.MultiReader(io.TeeReader(io.MultiReader(bytes.NewReader(header), sr), hr), hr), nil
}

// decodeHeader decodes the header of the reader.
// It returns the keys, IV, and original header using the password and iterations in the reader.
func decodeHeader(r io.Reader, password []byte) (aesKey, hmacKey, iv, header []byte, err error) {
itersAsBytes, iterations, err := decInt32(r)
if err != nil {
return nil, nil, nil, nil, err
}
salt := make([]byte, saltSize)
iv = make([]byte, blockSize)
_, err = io.ReadFull(r, salt)
if err != nil {
return nil, nil, nil, nil, err
}
_, err = io.ReadFull(r, iv)
if err != nil {
return nil, nil, nil, nil, err
}
aesKey, hmacKey, err = keys(password, salt, int(iterations))
if err != nil {
return nil, nil, nil, nil, err
}
header = append(header, itersAsBytes...)
header = append(header, salt...)
header = append(header, iv...)
return io.MultiReader(io.TeeReader(io.MultiReader(bytes.NewBuffer(header), sr), hr), hr), nil
return aesKey, hmacKey, iv, header, err
}

// decryptReader wraps a io.Reader decrypting its content.
Expand All @@ -160,30 +229,14 @@ type decryptReader struct {
// hash the contents to verify that it is safe to decrypt.
// If the file is athenticated, the DecryptReader will be returned and
// the resulting bytes will be the plaintext.
func NewDecryptReader(r io.Reader, pass []byte) (io.ReadCloser, error) {
return newDecryptReader(r, pass, scryptIterations)
}

func newDecryptReader(r io.Reader, pass []byte, iterations int) (d *decryptReader, err error) {
salt := make([]byte, saltSize)
iv := make([]byte, blockSize)
func NewDecryptReader(r io.Reader, pass []byte) (d io.ReadCloser, err error) {
mac := make([]byte, hmacSize)
_, err = io.ReadFull(r, salt)
if err != nil {
return nil, err
}
_, err = io.ReadFull(r, iv)
if err != nil {
return nil, err
}
aesKey, hmacKey, err := keys(pass, salt, iterations)
aesKey, hmacKey, iv, header, err := decodeHeader(r, pass)
h := hmac.New(hashFunc, hmacKey)
h.Write(header)
if err != nil {
return nil, err
}
// Start Verifying the HMAC of the message.
h := hmac.New(hashFunc, hmacKey)
h.Write(salt)
h.Write(iv)
dst, err := tmpfile.New(&tmpfile.Context{
Dir: os.TempDir(),
Suffix: "drive-encrypted-",
Expand Down Expand Up @@ -214,6 +267,9 @@ func newDecryptReader(r io.Reader, pass []byte, iterations int) (d *decryptReade
}
if err == io.EOF {
left := buf.Buffered()
if left < hmacSize {
return nil, DecryptErr
}
copy(mac, b[left-hmacSize:left])
_, err = io.CopyN(w, buf, int64(left-hmacSize))
if err != nil {
Expand Down Expand Up @@ -242,3 +298,19 @@ func (d *decryptReader) Read(dst []byte) (int, error) {
func (d *decryptReader) Close() error {
return d.tmpFile.Done()
}

// Hash hashes the plaintext based on the header of the encrypted file and returns the hash Sum.
func Hash(plainTextR io.Reader, headerR io.Reader, password []byte, h hash.Hash) ([]byte, error) {
aesKey, hmacKey, iv, eHeader, err := decodeHeader(headerR, password)
if err != nil {
return nil, err
}
encReader, err := encrypter(plainTextR, aesKey, hmacKey, iv, eHeader)
if err != nil {
return nil, err
}
if _, err := io.Copy(h, encReader); err != nil {
return nil, err
}
return h.Sum(nil), nil
}
Loading

0 comments on commit 87197b6

Please sign in to comment.