Skip to content

Commit

Permalink
refactor(test): improve some more tests using improved OZK framework
Browse files Browse the repository at this point in the history
  • Loading branch information
jan-ferdinand committed Jan 19, 2024
1 parent 3446d2c commit f87e37a
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 111 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
#![allow(clippy::needless_borrow)]

use triton_vm::BFieldElement;

use crate::tests_and_benchmarks::ozk::rust_shadows as tasm;
Expand Down Expand Up @@ -58,9 +56,10 @@ fn main() {

#[cfg(test)]
mod test {

use itertools::Itertools;
use rand::random;
use proptest::collection::vec;
use proptest::prelude::*;
use proptest_arbitrary_interop::arb;
use test_strategy::proptest;
use triton_vm::twenty_first::shared_math::bfield_codec::BFieldCodec;
use triton_vm::BFieldElement;
use triton_vm::NonDeterminism;
Expand All @@ -71,48 +70,24 @@ mod test {

use super::*;

#[test]
fn dazefield_element_test() {
// Test function on host machine
let non_determinism = NonDeterminism::new(vec![]);

for _ in 0..4 {
let a: BFieldElement = random();
let b: BFieldElement = random();
let res = a * b;
let stdin: Vec<BFieldElement> = vec![a, b];
let expected_output = [vec![res], res.value().encode()].concat();
let native_output =
rust_shadows::wrap_main_with_io(&main)(stdin.clone(), non_determinism.clone());
assert_eq!(native_output, expected_output);

// Test function in Triton VM
let entrypoint_location = ozk_parsing::EntrypointLocation::disk(
"arithmetic",
"dazefield_element_mul",
"main",
);
let test_program = ozk_parsing::compile_for_test(
&entrypoint_location,
crate::ast_types::ListType::Unsafe,
);
let expected_stack_diff = 0;
let vm_output = execute_compiled_with_stack_and_ins_for_test(
&test_program,
vec![],
stdin,
NonDeterminism::new(vec![]),
expected_stack_diff,
)
#[proptest(cases = 20)]
fn dazefield_element_test(#[strategy(vec(arb(), 2))] std_in: Vec<BFieldElement>) {
let native_output =
rust_shadows::wrap_main_with_io(&main)(std_in.clone(), NonDeterminism::default());

let res = std_in[0] * std_in[1];
let expected_output = [vec![res], res.value().encode()].concat();
assert_eq!(native_output, expected_output);

let entrypoint_location =
ozk_parsing::EntrypointLocation::disk("arithmetic", "dazefield_element_mul", "main");
let vm_output = TritonVMTestCase::new(entrypoint_location)
.with_std_in(std_in)
.expect_stack_difference(0)
.execute()
.unwrap();
if expected_output != vm_output.output {
panic!(
"expected:\n{}\n\ngot:\n{}",
expected_output.iter().join(","),
vm_output.output.iter().join(",")
);
}
}

prop_assert_eq!(expected_output, vm_output.output);
}
}

Expand Down
24 changes: 7 additions & 17 deletions src/tests_and_benchmarks/ozk/programs/arithmetic/montyred.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,8 @@ mod test {

#[test]
fn montyred_test() {
// Test function on host machine
let stdin = vec![];
let non_determinism = NonDeterminism::new(vec![]);
let non_determinism = NonDeterminism::default();
let expected_output = [
BFieldElement::montyred(1u128 << 90).encode(),
BFieldElement::montyred(1).encode(),
Expand All @@ -81,25 +80,16 @@ mod test {
BFieldElement::montyred(2u128 * 0xFFFFFFFE00000001u128).encode(),
]
.concat();
let native_output =
rust_shadows::wrap_main_with_io(&main)(stdin.clone(), non_determinism.clone());
let native_output = rust_shadows::wrap_main_with_io(&main)(stdin, non_determinism);
assert_eq!(native_output, expected_output);

// Test function in Triton VM
// Run test on Triton-VM
let entrypoint_location =
ozk_parsing::EntrypointLocation::disk("arithmetic", "montyred", "main");
let test_program =
ozk_parsing::compile_for_test(&entrypoint_location, crate::ast_types::ListType::Unsafe);
let expected_stack_diff = 0;
let vm_output = execute_compiled_with_stack_and_ins_for_test(
&test_program,
vec![],
vec![],
NonDeterminism::new(vec![]),
expected_stack_diff,
)
.unwrap();
let vm_output = TritonVMTestCase::new(entrypoint_location)
.expect_stack_difference(0)
.execute()
.unwrap();

if expected_output != vm_output.output {
panic!(
"expected:\n{}\n\ngot:\n{}",
Expand Down
70 changes: 22 additions & 48 deletions src/tests_and_benchmarks/ozk/programs/arrays/xfe_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,79 +55,53 @@ fn main() {

#[cfg(test)]
mod test {

use itertools::Itertools;
use rand::random;
use triton_vm::twenty_first::shared_math::bfield_codec::BFieldCodec;
use triton_vm::BFieldElement;
use triton_vm::NonDeterminism;

use crate::tests_and_benchmarks::ozk::ozk_parsing;
use crate::tests_and_benchmarks::ozk::ozk_parsing::EntrypointLocation;
use crate::tests_and_benchmarks::ozk::rust_shadows;
use crate::tests_and_benchmarks::test_helpers::shared_test::execute_compiled_with_stack_and_ins_for_test;
use crate::tests_and_benchmarks::test_helpers::shared_test::TritonVMTestCase;

use proptest_arbitrary_interop::arb;
use test_strategy::proptest;

use super::*;

#[test]
fn xfe_array_test() {
#[proptest(cases = 20)]
fn xfe_array_test(#[strategy(arb())] xfes: [XFieldElement; 4]) {
let std_in = xfes
.map(|xfe| xfe.encode().into_iter().rev())
.into_iter()
.flatten()
.collect_vec();
let non_determinism = NonDeterminism::default();

let xfes: [XFieldElement; 4] = random();
let stdin = {
let mut ret = vec![];
for elem in xfes {
let mut elem = elem.encode();
elem.reverse();
ret.append(&mut elem);
}

ret
};
let native_output = rust_shadows::wrap_main_with_io(&main)(std_in.clone(), non_determinism);
println!("native_output: {native_output:#?}");

let expected_output = vec![
xfes[3].encode(),
xfes[0].encode(),
XFieldElement::new([
BFieldElement::new(52),
BFieldElement::new(53),
BFieldElement::new(54),
])
.encode(),
XFieldElement::new([
BFieldElement::new(52),
BFieldElement::new(53),
BFieldElement::new(54),
])
.encode(),
XFieldElement::new([52, 53, 54].map(BFieldElement::new)).encode(),
XFieldElement::new([52, 53, 54].map(BFieldElement::new)).encode(),
xfes[1].encode(),
// a, b, c, l, k
vec![BFieldElement::new(100)],
vec![BFieldElement::new(200)],
vec![BFieldElement::new(400)],
[100, 200, 400].map(BFieldElement::new).to_vec(),
xfes[2].encode(),
vec![BFieldElement::new(1337)],
]
.concat();

// Run test on host machine
let native_output =
rust_shadows::wrap_main_with_io(&main)(stdin.to_vec(), non_determinism.clone());
println!("native_output: {native_output:#?}");
assert_eq!(native_output, expected_output);

// Run test on Triton-VM
let entrypoint_location = EntrypointLocation::disk("arrays", "xfe_array", "main");
let test_program =
ozk_parsing::compile_for_test(&entrypoint_location, crate::ast_types::ListType::Unsafe);
let vm_output = execute_compiled_with_stack_and_ins_for_test(
&test_program,
vec![],
stdin,
non_determinism,
0,
)
.unwrap();
let vm_output = TritonVMTestCase::new(entrypoint_location)
.with_std_in(std_in)
.expect_stack_difference(0)
.execute()
.unwrap();

assert_eq!(expected_output, vm_output.output);
println!("vm_output.output: {:#?}", vm_output.output);

Expand Down

0 comments on commit f87e37a

Please sign in to comment.