Skip to content

Commit

Permalink
chacha: Use Overlapping in the implementation of the fallback impl.
Browse files Browse the repository at this point in the history
Eliminate all of the `unsafe` in the fallback implementation.
  • Loading branch information
briansmith committed Jan 9, 2025
1 parent 504685d commit d2e401f
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 20 deletions.
31 changes: 17 additions & 14 deletions src/aead/chacha/fallback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
// Adapted from the public domain, estream code by D. Bernstein.
// Adapted from the BoringSSL crypto/chacha/chacha.c.

use super::{Counter, Key, Overlapping, BLOCK_LEN};
use super::{super::overlapping::IndexError, Counter, Key, Overlapping, BLOCK_LEN};
use crate::{constant_time, polyfill::sliceutil};
use core::{mem::size_of, slice};
use core::mem::size_of;

pub(super) fn ChaCha20_ctr32(key: &Key, counter: Counter, in_out: Overlapping<'_>) {
pub(super) fn ChaCha20_ctr32(key: &Key, counter: Counter, mut in_out: Overlapping<'_>) {
const SIGMA: [u32; 4] = [
u32::from_le_bytes(*b"expa"),
u32::from_le_bytes(*b"nd 3"),
Expand All @@ -35,31 +35,34 @@ pub(super) fn ChaCha20_ctr32(key: &Key, counter: Counter, in_out: Overlapping<'_
key[6], key[7], counter[0], counter[1], counter[2], counter[3],
];

let (mut input, mut output, mut in_out_len) = in_out.into_input_output_len();
let mut in_out_len = in_out.len();

let mut buf = [0u8; BLOCK_LEN];
while in_out_len > 0 {
chacha_core(&mut buf, &state);
state[12] += 1;

debug_assert_eq!(in_out_len, in_out.len());

// Both branches do the same thing, but the duplication helps the
// compiler optimize (vectorize) the `BLOCK_LEN` case.
if in_out_len >= BLOCK_LEN {
let input = unsafe { slice::from_raw_parts(input, BLOCK_LEN) };
constant_time::xor_assign_at_start(&mut buf, input);
let output = unsafe { slice::from_raw_parts_mut(output, BLOCK_LEN) };
sliceutil::overwrite_at_start(output, &buf);
in_out = in_out
.split_first_chunk::<BLOCK_LEN>(|in_out| {
constant_time::xor_assign_at_start(&mut buf, in_out.input());
sliceutil::overwrite_at_start(in_out.into_unwritten_output(), &buf);
})
.unwrap_or_else(|IndexError { .. }| {
// Since `in_out_len == in_out.len() && in_out_len >= BLOCK_LEN`.
unreachable!()
});
} else {
let input = unsafe { slice::from_raw_parts(input, in_out_len) };
constant_time::xor_assign_at_start(&mut buf, input);
let output = unsafe { slice::from_raw_parts_mut(output, in_out_len) };
sliceutil::overwrite_at_start(output, &buf);
constant_time::xor_assign_at_start(&mut buf, in_out.input());
sliceutil::overwrite_at_start(in_out.into_unwritten_output(), &buf);
break;
}

in_out_len -= BLOCK_LEN;
input = unsafe { input.add(BLOCK_LEN) };
output = unsafe { output.add(BLOCK_LEN) };
}
}

Expand Down
72 changes: 72 additions & 0 deletions src/aead/overlapping/array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright 2024 Brian Smith.
//
// Permission to use, copy, modify, and/or distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHORS DISCLAIM ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

#![cfg_attr(not(test), allow(dead_code))]

use super::Overlapping;
use core::array::TryFromSliceError;

pub struct Array<'o, T, const N: usize> {
// Invariant: N != 0.
// Invariant: `self.in_out.len() == N`.
in_out: Overlapping<'o, T>,
}

impl<'o, T, const N: usize> Array<'o, T, N> {
pub(super) fn new(in_out: Overlapping<'o, T>) -> Result<Self, LenMismatchError> {
if N == 0 || in_out.len() != N {
return Err(LenMismatchError::new(N));
}
Ok(Self { in_out })
}

pub fn into_unwritten_output(self) -> &'o mut [T; N]
where
&'o mut [T]: TryInto<&'o mut [T; N], Error = TryFromSliceError>,
{
self.in_out
.into_unwritten_output()
.try_into()
.unwrap_or_else(|TryFromSliceError { .. }| {
unreachable!() // Due to invariant
})
}
}

impl<T, const N: usize> Array<'_, T, N> {
pub fn input<'s>(&'s self) -> &'s [T; N]
where
&'s [T]: TryInto<&'s [T; N], Error = TryFromSliceError>,
{
self.in_out
.input()
.try_into()
.unwrap_or_else(|TryFromSliceError { .. }| {
unreachable!() // Due to invariant
})
}
}

pub struct LenMismatchError {
#[allow(dead_code)]
len: usize,
}

impl LenMismatchError {
#[cold]
#[inline(never)]
fn new(len: usize) -> Self {
Self { len }
}
}
57 changes: 51 additions & 6 deletions src/aead/overlapping/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

use core::ops::RangeFrom;
use super::{Array, LenMismatchError};
use core::{mem, ops::RangeFrom};

pub struct Overlapping<'o, T> {
// Invariant: self.src.start <= in_out.len().
Expand All @@ -28,7 +29,7 @@ impl<'o, T> Overlapping<'o, T> {
pub fn new(in_out: &'o mut [T], src: RangeFrom<usize>) -> Result<Self, IndexError> {
match in_out.get(src.clone()) {
Some(_) => Ok(Self { in_out, src }),
None => Err(IndexError::new(src)),
None => Err(IndexError::new(src.start)),
}
}

Expand All @@ -51,7 +52,7 @@ impl<'o, T> Overlapping<'o, T> {
(self.in_out, self.src)
}

pub(super) fn into_unwritten_output(self) -> &'o mut [T] {
pub fn into_unwritten_output(self) -> &'o mut [T] {
let len = self.len();
self.in_out.get_mut(..len).unwrap_or_else(|| {
// The invariant ensures this succeeds.
Expand Down Expand Up @@ -83,14 +84,58 @@ impl<T> Overlapping<'_, T> {
let input = unsafe { output_const.add(self.src.start) };
(input, output, len)
}

// Perhaps unlike `slice::split_first_chunk_mut`, this is biased,
// performance-wise, against the case where `N > self.len()`, so callers
// should be structured to avoid that.
//
// If the result is `Err` then nothing was written to `self`; if anything
// was written then the result will not be `Err`.
#[cfg_attr(not(test), allow(dead_code))]
pub fn split_first_chunk<const N: usize>(
mut self,
f: impl for<'a> FnOnce(Array<'a, T, N>),
) -> Result<Self, IndexError> {
let src = self.src.clone();
let end = self
.src
.start
.checked_add(N)
.ok_or_else(|| IndexError::new(N))?;
let first = self
.in_out
.get_mut(..end)
.ok_or_else(|| IndexError::new(N))?;
let first = Overlapping::new(first, src).unwrap_or_else(|IndexError { .. }| {
// Since `end == src.start + N`.
unreachable!()
});
let first = Array::new(first).unwrap_or_else(|LenMismatchError { .. }| {
// Since `end == src.start + N`.
unreachable!()
});
// Once we call `f`, we must return `Ok` because `f` may have written
// over (part of) the input.
Ok({
f(first);
let tail = mem::take(&mut self.in_out).get_mut(N..).unwrap_or_else(|| {
// There are at least `N` elements since `end == src.start + N`.
unreachable!()
});
Self::new(tail, self.src).unwrap_or_else(|IndexError { .. }| {
// Follows from `end == src.start + N`.
unreachable!()
})
})
}
}

pub struct IndexError(#[allow(dead_code)] RangeFrom<usize>);
pub struct IndexError(#[allow(dead_code)] usize);

impl IndexError {
#[cold]
#[inline(never)]
fn new(src: RangeFrom<usize>) -> Self {
Self(src)
fn new(index: usize) -> Self {
Self(index)
}
}
4 changes: 4 additions & 0 deletions src/aead/overlapping/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,13 @@
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

pub use self::{
array::Array,
base::{IndexError, Overlapping},
partial_block::PartialBlock,
};

use self::array::LenMismatchError;

mod array;
mod base;
mod partial_block;

0 comments on commit d2e401f

Please sign in to comment.