Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrap MastForest in Program and Library in Arc #1465

Merged
merged 9 commits into from
Aug 27, 2024
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Changelog

## 0.11.0 (TBD)

#### Changes

- [BREAKING] Wrapped `MastForest`s in `Program` and `Library` structs in `Arc` (#1465).


## 0.10.5 (2024-08-21)

#### Enhancements
Expand Down
6 changes: 3 additions & 3 deletions assembly/src/assembler/instruction/procedures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ impl Assembler {
})
}
})?;
proc_ctx.register_external_call(&proc, false)?;
proc_ctx.register_external_call(proc, false)?;
},
Some(proc) => proc_ctx.register_external_call(&proc, false)?,
Some(proc) => proc_ctx.register_external_call(proc, false)?,
None => (),
}

Expand Down Expand Up @@ -169,7 +169,7 @@ impl Assembler {
// with the referenced procedure later

if let Some(proc) = mast_forest_builder.find_procedure(&mast_root) {
proc_ctx.register_external_call(&proc, false)?;
proc_ctx.register_external_call(proc, false)?;
}

// Create an array with `Push` operations containing root elements
Expand Down
82 changes: 41 additions & 41 deletions assembly/src/assembler/mast_forest_builder.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use alloc::{
collections::{BTreeMap, BTreeSet},
sync::Arc,
vec::Vec,
};
use core::ops::Index;

use vm_core::{
crypto::hash::RpoDigest,
Expand All @@ -24,14 +22,35 @@ const PROCEDURE_INLINING_THRESHOLD: usize = 32;
// ================================================================================================

/// Builder for a [`MastForest`].
///
/// The purpose of the builder is to ensure that the underlying MAST forest contains as little
/// information as possible needed to adequately describe the logical MAST forest. Specifically:
/// - The builder ensures that only one copy of a given node exists in the MAST forest (i.e., no two
/// nodes have the same hash).
/// - The builder tries to merge adjacent basic blocks and eliminate the source block whenever this
/// does not have an impact on other nodes in the forest.
#[derive(Clone, Debug, Default)]
pub struct MastForestBuilder {
/// The MAST forest being built by this builder; this MAST forest is up-to-date - i.e., all
/// nodes added to the MAST forest builder are also immediately added to the underlying MAST
/// forest.
mast_forest: MastForest,
/// A map of MAST node digests to their corresponding positions in the MAST forest. It is
/// guaranteed that a given digest maps to exactly one node in the MAST forest.
node_id_by_hash: BTreeMap<RpoDigest, MastNodeId>,
procedures: BTreeMap<GlobalProcedureIndex, Arc<Procedure>>,
procedure_hashes: BTreeMap<GlobalProcedureIndex, RpoDigest>,
/// A map of all procedures added to the MAST forest indexed by their global procedure ID.
/// This includes all local, exported, and re-exported procedures. In case multiple procedures
/// with the same digest are added to the MAST forest builder, only the first procedure is
/// added to the map, and all subsequent insertions are ignored.
procedures: BTreeMap<GlobalProcedureIndex, Procedure>,
/// A map from procedure MAST root to its global procedure index. Similar to the `procedures`
/// map, this map contains only the first inserted procedure for procedures with the same MAST
/// root.
proc_gid_by_hash: BTreeMap<RpoDigest, GlobalProcedureIndex>,
merged_node_ids: BTreeSet<MastNodeId>,
/// A set of IDs for basic blocks which have been merged into bigger basic blocks. This is
/// used as a candidate set of nodes that may be eliminated if they are not referenced by any
/// other node in the forest and are not a root of any procedure.
merged_basic_block_ids: BTreeSet<MastNodeId>,
}

impl MastForestBuilder {
Expand All @@ -42,7 +61,7 @@ impl MastForestBuilder {
/// unchanged. Any [`MastNodeId`] used in reference to the old [`MastForest`] should be remapped
/// using this map.
pub fn build(mut self) -> (MastForest, Option<BTreeMap<MastNodeId, MastNodeId>>) {
let nodes_to_remove = get_nodes_to_remove(self.merged_node_ids, &self.mast_forest);
let nodes_to_remove = get_nodes_to_remove(self.merged_basic_block_ids, &self.mast_forest);
let id_remappings = self.mast_forest.remove_nodes(&nodes_to_remove);

(self.mast_forest, id_remappings)
Expand Down Expand Up @@ -109,21 +128,21 @@ impl MastForestBuilder {
/// Returns a reference to the procedure with the specified [`GlobalProcedureIndex`], or None
/// if such a procedure is not present in this MAST forest builder.
#[inline(always)]
pub fn get_procedure(&self, gid: GlobalProcedureIndex) -> Option<Arc<Procedure>> {
self.procedures.get(&gid).cloned()
pub fn get_procedure(&self, gid: GlobalProcedureIndex) -> Option<&Procedure> {
self.procedures.get(&gid)
}

/// Returns the hash of the procedure with the specified [`GlobalProcedureIndex`], or None if
/// such a procedure is not present in this MAST forest builder.
#[inline(always)]
pub fn get_procedure_hash(&self, gid: GlobalProcedureIndex) -> Option<RpoDigest> {
self.procedure_hashes.get(&gid).cloned()
self.procedures.get(&gid).map(|proc| proc.mast_root())
}

/// Returns a reference to the procedure with the specified MAST root, or None
/// if such a procedure is not present in this MAST forest builder.
#[inline(always)]
pub fn find_procedure(&self, mast_root: &RpoDigest) -> Option<Arc<Procedure>> {
pub fn find_procedure(&self, mast_root: &RpoDigest) -> Option<&Procedure> {
self.proc_gid_by_hash.get(mast_root).and_then(|gid| self.get_procedure(*gid))
}

Expand All @@ -141,18 +160,9 @@ impl MastForestBuilder {
}
}

// ------------------------------------------------------------------------------------------------
/// Procedure insertion
impl MastForestBuilder {
pub fn insert_procedure_hash(
&mut self,
gid: GlobalProcedureIndex,
proc_hash: RpoDigest,
) -> Result<(), AssemblyError> {
// TODO(plafer): Check if exists
self.procedure_hashes.insert(gid, proc_hash);

Ok(())
}

/// Inserts a procedure into this MAST forest builder.
///
/// If the procedure with the same ID already exists in this forest builder, this will have
Expand Down Expand Up @@ -202,19 +212,17 @@ impl MastForestBuilder {
}
}

self.make_root(procedure.body_node_id());
self.mast_forest.make_root(procedure.body_node_id());
self.proc_gid_by_hash.insert(proc_root, gid);
self.insert_procedure_hash(gid, procedure.mast_root())?;
self.procedures.insert(gid, Arc::new(procedure));
self.procedures.insert(gid, procedure);

Ok(())
}
}

/// Marks the given [`MastNodeId`] as being the root of a procedure.
pub fn make_root(&mut self, new_root_id: MastNodeId) {
self.mast_forest.make_root(new_root_id)
}

// ------------------------------------------------------------------------------------------------
/// Joining nodes
impl MastForestBuilder {
/// Builds a tree of `JOIN` operations to combine the provided MAST node IDs.
pub fn join_nodes(&mut self, node_ids: Vec<MastNodeId>) -> Result<MastNodeId, AssemblyError> {
debug_assert!(!node_ids.is_empty(), "cannot combine empty MAST node id list");
Expand Down Expand Up @@ -254,7 +262,7 @@ impl MastForestBuilder {
let mut contiguous_basic_block_ids: Vec<MastNodeId> = Vec::new();

for mast_node_id in node_ids {
if self[mast_node_id].is_basic_block() {
if self.mast_forest[mast_node_id].is_basic_block() {
contiguous_basic_block_ids.push(mast_node_id);
} else {
merged_node_ids.extend(self.merge_basic_blocks(&contiguous_basic_block_ids)?);
Expand Down Expand Up @@ -293,7 +301,8 @@ impl MastForestBuilder {
for &basic_block_id in contiguous_basic_block_ids {
// It is safe to unwrap here, since we already checked that all IDs in
// `contiguous_basic_block_ids` are `BasicBlockNode`s
let basic_block_node = self[basic_block_id].get_basic_block().unwrap().clone();
let basic_block_node =
self.mast_forest[basic_block_id].get_basic_block().unwrap().clone();

// check if the block should be merged with other blocks
if should_merge(
Expand Down Expand Up @@ -322,7 +331,7 @@ impl MastForestBuilder {
}

// Mark the removed basic blocks as merged
self.merged_node_ids.extend(contiguous_basic_block_ids.iter());
self.merged_basic_block_ids.extend(contiguous_basic_block_ids.iter());

if !operations.is_empty() || !decorators.is_empty() {
let merged_basic_block = self.ensure_block(operations, Some(decorators))?;
Expand Down Expand Up @@ -414,15 +423,6 @@ impl MastForestBuilder {
}
}

impl Index<MastNodeId> for MastForestBuilder {
type Output = MastNode;

#[inline(always)]
fn index(&self, node_id: MastNodeId) -> &Self::Output {
&self.mast_forest[node_id]
}
}

// HELPER FUNCTIONS
// ================================================================================================

Expand Down
42 changes: 26 additions & 16 deletions assembly/src/assembler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ impl Assembler {

// TODO: show a warning if library exports are empty?
let (mast_forest, _) = mast_forest_builder.build();
Ok(Library::new(mast_forest, exports))
Ok(Library::new(mast_forest.into(), exports)?)
}

/// Assembles the provided module into a [KernelLibrary] intended to be used as a Kernel.
Expand Down Expand Up @@ -343,7 +343,7 @@ impl Assembler {
// TODO: show a warning if library exports are empty?

let (mast_forest, _) = mast_forest_builder.build();
let library = Library::new(mast_forest, exports);
let library = Library::new(mast_forest.into(), exports)?;
Ok(library.try_into()?)
}

Expand Down Expand Up @@ -379,21 +379,19 @@ impl Assembler {
// Compile the module graph rooted at the entrypoint
let mut mast_forest_builder = MastForestBuilder::default();
self.compile_subgraph(entrypoint, &mut mast_forest_builder)?;
let entry_procedure = mast_forest_builder
let entry_node_id = mast_forest_builder
.get_procedure(entrypoint)
.expect("compilation succeeded but root not found in cache");
.expect("compilation succeeded but root not found in cache")
.body_node_id();

// in case the node IDs changed, update the entrypoint ID to the new value
let (mast_forest, id_remappings) = mast_forest_builder.build();
let entry_node_id = {
let old_entry_node_id = entry_procedure.body_node_id();

id_remappings
.map(|id_remappings| id_remappings[&old_entry_node_id])
.unwrap_or(old_entry_node_id)
};
let entry_node_id = id_remappings
.map(|id_remappings| id_remappings[&entry_node_id])
.unwrap_or(entry_node_id);

Ok(Program::with_kernel(
mast_forest,
mast_forest.into(),
entry_node_id,
self.module_graph.kernel().clone(),
))
Expand Down Expand Up @@ -473,8 +471,13 @@ impl Assembler {

// Compile this procedure
let procedure = self.compile_procedure(pctx, mast_forest_builder)?;
// TODO: if a re-exported procedure with the same MAST root had been previously
// added to the builder, this will result in unreachable nodes added to the
// MAST forest. This is because while we won't insert a duplicate node for the
// procedure body node itself, all nodes that make up the procedure body would
// be added to the forest.

// Cache the compiled procedure.
// Cache the compiled procedure
self.module_graph.register_mast_root(procedure_gid, procedure.mast_root())?;
mast_forest_builder.insert_procedure(procedure_gid, procedure)?;
},
Expand All @@ -493,15 +496,22 @@ impl Assembler {
)
.with_span(proc_alias.span());

let proc_alias_root = self.resolve_target(
let proc_mast_root = self.resolve_target(
InvokeKind::ProcRef,
&proc_alias.target().into(),
&pctx,
mast_forest_builder,
)?;

// insert external node into the MAST forest for this procedure; if a procedure
// with the same MAST root had been previously added to the builder, this will
// have no effect
let proc_node_id = mast_forest_builder.ensure_external(proc_mast_root)?;
let procedure = pctx.into_procedure(proc_mast_root, proc_node_id);
bobbinth marked this conversation as resolved.
Show resolved Hide resolved

// Make the MAST root available to all dependents
self.module_graph.register_mast_root(procedure_gid, proc_alias_root)?;
mast_forest_builder.insert_procedure_hash(procedure_gid, proc_alias_root)?;
self.module_graph.register_mast_root(procedure_gid, proc_mast_root)?;
mast_forest_builder.insert_procedure(procedure_gid, procedure)?;
},
}
}
Expand Down
6 changes: 4 additions & 2 deletions assembly/src/assembler/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,9 @@ fn nested_blocks() -> Result<(), Report> {
.join_nodes(vec![before, r#if1, nested, exec_foo_bar_baz_node_id, syscall_foo_node_id])
.unwrap();

let expected_program = Program::new(expected_mast_forest_builder.build().0, combined_node_id);
let mut expected_mast_forest = expected_mast_forest_builder.build().0;
expected_mast_forest.make_root(combined_node_id);
let expected_program = Program::new(expected_mast_forest.into(), combined_node_id);
assert_eq!(expected_program.hash(), program.hash());

// also check that the program has the right number of procedures (which excludes the dummy
Expand Down Expand Up @@ -214,7 +216,7 @@ fn duplicate_nodes() {

expected_mast_forest.make_root(root_id);

let expected_program = Program::new(expected_mast_forest, root_id);
let expected_program = Program::new(expected_mast_forest.into(), root_id);

assert_eq!(program, expected_program);
}
Expand Down
2 changes: 2 additions & 0 deletions assembly/src/library/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,6 @@ pub enum LibraryError {
InvalidKernelExport { procedure_path: QualifiedProcedureName },
#[error(transparent)]
Kernel(#[from] KernelError),
#[error("invalid export: no procedure root for {procedure_path} procedure")]
NoProcedureRootForExport { procedure_path: QualifiedProcedureName },
}
Loading
Loading