Skip to content

Commit

Permalink
[trie] Push tracking proof_size_limit to trie recorder (#12710)
Browse files Browse the repository at this point in the history
This PR is the first part of
#12701

The PR moves tracking of proof_size_limit from runtime to trie recorder.
There should be no functional change.

The trie recorder is a more natural place to expose the proof_size_limit check,
and this change lays the groundwork for future improvements such as
- Potentially moving compute_limit to recorder as well
- Better tracking and checking of limits (required for resharding)
- Potential setup to add better checks for limits (required for reading
and managing buffered receipts)
  • Loading branch information
shreyan-gupta authored Jan 11, 2025
1 parent cbef830 commit 6e62db1
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 82 deletions.
2 changes: 1 addition & 1 deletion chain/chain/src/resharding/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ impl ReshardingManager {
"Creating child memtrie by retaining nodes in parent memtrie..."
);
let mut mem_tries = mem_tries.write().unwrap();
let mut trie_recorder = TrieRecorder::new();
let mut trie_recorder = TrieRecorder::new(None);
let mode = TrackingMode::RefcountsAndAccesses(&mut trie_recorder);
let mem_trie_update = mem_tries.update(*parent_chunk_extra.state_root(), mode)?;

Expand Down
14 changes: 10 additions & 4 deletions chain/chain/src/runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,7 @@ impl RuntimeAdapter for NightshadeRuntime {

let epoch_id = self.epoch_manager.get_epoch_id_from_prev_block(&prev_block.block_hash)?;
let protocol_version = self.epoch_manager.get_epoch_protocol_version(&epoch_id)?;
let runtime_config = self.runtime_config_store.get_config(protocol_version);

let next_epoch_id =
self.epoch_manager.get_next_epoch_id_from_prev_block(&(&prev_block.block_hash))?;
Expand Down Expand Up @@ -650,16 +651,16 @@ impl RuntimeAdapter for NightshadeRuntime {
if ProtocolFeature::StatelessValidation.enabled(next_protocol_version)
|| cfg!(feature = "shadow_chunk_validation")
{
trie = trie.recording_reads_new_recorder();
let proof_size_limit =
runtime_config.witness_config.new_transactions_validation_state_size_soft_limit;
trie = trie.recording_reads_with_proof_size_limit(proof_size_limit);
}
let mut state_update = TrieUpdate::new(trie);

// Total amount of gas burnt for converting transactions towards receipts.
let mut total_gas_burnt = 0;
let mut total_size = 0u64;

let runtime_config = self.runtime_config_store.get_config(protocol_version);

let transactions_gas_limit =
chunk_tx_gas_limit(protocol_version, runtime_config, &prev_block, shard_id, gas_limit);

Expand Down Expand Up @@ -882,7 +883,12 @@ impl RuntimeAdapter for NightshadeRuntime {
if ProtocolFeature::StatelessValidation.enabled(next_protocol_version)
|| cfg!(feature = "shadow_chunk_validation")
{
trie = trie.recording_reads_new_recorder();
let epoch_id =
self.epoch_manager.get_epoch_id_from_prev_block(&block.prev_block_hash)?;
let protocol_version = self.epoch_manager.get_epoch_protocol_version(&epoch_id)?;
let config = self.runtime_config_store.get_config(protocol_version);
let proof_limit = config.witness_config.main_storage_proof_size_soft_limit;
trie = trie.recording_reads_with_proof_size_limit(proof_limit);
}

match self.process_state_update(
Expand Down
16 changes: 15 additions & 1 deletion core/store/src/trie/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -723,11 +723,18 @@ impl Trie {
/// Makes a new trie that has everything the same except that access
/// through that trie accumulates a state proof for all nodes accessed.
pub fn recording_reads_new_recorder(&self) -> Self {
self.recording_reads_with_recorder(RefCell::new(TrieRecorder::new()))
let recorder = RefCell::new(TrieRecorder::new(None));
self.recording_reads_with_recorder(recorder)
}

/// Like `recording_reads_new_recorder`, but the underlying recorder is
/// created with a soft `proof_size_limit`, so the accumulated state proof
/// can be checked against the limit while recording.
pub fn recording_reads_with_proof_size_limit(&self, proof_size_limit: usize) -> Self {
    // Construct the limited recorder inline; it carries the limit from the start.
    self.recording_reads_with_recorder(RefCell::new(TrieRecorder::new(Some(proof_size_limit))))
}

pub fn recording_reads_with_recorder(&self, recorder: RefCell<TrieRecorder>) -> Self {
let mut trie = Self::new_with_memtries(
self.storage.clone(),
Expand Down Expand Up @@ -766,6 +773,13 @@ impl Trie {
.unwrap_or_default()
}

/// Returns true when a recorder is attached and it reports that the recorded
/// proof size upper bound has gone over its soft limit.
/// Returns false when no recorder is attached.
pub fn check_proof_size_limit_exceed(&self) -> bool {
    match self.recorder.as_ref() {
        Some(recorder) => recorder.borrow().check_proof_size_limit_exceed(),
        None => false,
    }
}

/// Constructs a Trie from the partial storage (i.e. state proof) that
/// was returned from recorded_storage(). If used to access the same trie
/// nodes as when the partial storage was generated, this trie will behave
Expand Down
2 changes: 1 addition & 1 deletion core/store/src/trie/ops/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ fn run(initial_entries: Vec<(Vec<u8>, Vec<u8>)>, retain_multi_ranges: Vec<Range<
retain_split_shard_custom_ranges_for_trie(&trie, &retain_multi_ranges);

// Split memtrie and track proof
let mut trie_recorder = TrieRecorder::new();
let mut trie_recorder = TrieRecorder::new(None);
let mode = TrackingMode::RefcountsAndAccesses(&mut trie_recorder);
let mut update = memtries.update(initial_state_root, mode).unwrap();
retain_split_shard_custom_ranges(&mut update, &retain_multi_ranges);
Expand Down
41 changes: 26 additions & 15 deletions core/store/src/trie/trie_recording.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ use std::sync::Arc;
pub struct TrieRecorder {
recorded: HashMap<CryptoHash, Arc<[u8]>>,
size: usize,
/// Size of the recorded state proof plus some additional size added to cover removals and contract code.
/// An upper-bound estimation of the true recorded size after finalization.
/// See https://github.com/near/nearcore/issues/10890 and https://github.com/near/nearcore/pull/11000 for details.
upper_bound_size: usize,
/// Soft limit on the maximum size of the state proof that can be recorded.
proof_size_limit: Option<usize>,
/// Counts removals performed while recording.
/// recorded_storage_size_upper_bound takes it into account when calculating the total size.
removal_counter: usize,
Expand Down Expand Up @@ -45,10 +51,12 @@ pub struct SubtreeSize {
}

impl TrieRecorder {
pub fn new() -> Self {
pub fn new(proof_size_limit: Option<usize>) -> Self {
Self {
recorded: HashMap::new(),
size: 0,
upper_bound_size: 0,
proof_size_limit,
removal_counter: 0,
code_len_counter: 0,
codes_to_record: Default::default(),
Expand All @@ -66,16 +74,27 @@ impl TrieRecorder {
pub fn record(&mut self, hash: &CryptoHash, node: Arc<[u8]>) {
let size = node.len();
if self.recorded.insert(*hash, node).is_none() {
self.size += size;
self.size = self.size.checked_add(size).unwrap();
self.upper_bound_size = self.upper_bound_size.checked_add(size).unwrap();
}
}

pub fn record_removal(&mut self) {
self.removal_counter = self.removal_counter.saturating_add(1)
pub fn record_key_removal(&mut self) {
// Charge 2000 bytes for every removal
self.removal_counter = self.removal_counter.checked_add(1).unwrap();
self.upper_bound_size = self.upper_bound_size.checked_add(2000).unwrap();
}

pub fn record_code_len(&mut self, code_len: usize) {
self.code_len_counter = self.code_len_counter.saturating_add(code_len)
self.code_len_counter = self.code_len_counter.checked_add(code_len).unwrap();
self.upper_bound_size = self.upper_bound_size.checked_add(code_len).unwrap();
}

/// Whether the upper-bound estimate of the recorded proof size has exceeded
/// the configured soft limit. Always false when no limit was set.
pub fn check_proof_size_limit_exceed(&self) -> bool {
    self.proof_size_limit.is_some_and(|limit| self.upper_bound_size > limit)
}

pub fn recorded_storage(&mut self) -> PartialStorage {
Expand All @@ -88,19 +107,11 @@ impl TrieRecorder {
self.size
}

/// Size of the recorded state proof plus some additional size added to cover removals
/// and contract codes.
/// An upper-bound estimation of the true recorded size after finalization.
/// See https://github.com/near/nearcore/issues/10890 and https://github.com/near/nearcore/pull/11000 for details.
pub fn recorded_storage_size_upper_bound(&self) -> usize {
// Charge 2000 bytes for every removal
let removals_size = self.removal_counter.saturating_mul(2000);
self.recorded_storage_size()
.saturating_add(removals_size)
.saturating_add(self.code_len_counter)
self.upper_bound_size
}

/// Get statisitics about the recorded trie. Useful for observability and debugging.
/// Get statistics about the recorded trie. Useful for observability and debugging.
/// This scans all of the recorded data, so could potentially be expensive to run.
pub fn get_stats(&self, trie_root: &CryptoHash) -> TrieRecorderStats {
let mut trie_column_sizes = Vec::new();
Expand Down
2 changes: 1 addition & 1 deletion core/store/src/trie/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ impl TrieUpdate {
// by the runtime are assumed to be non-malicious and we don't charge extra for them.
if let Some(recorder) = &self.trie.recorder {
if matches!(trie_key, TrieKey::ContractData { .. }) {
recorder.borrow_mut().record_removal();
recorder.borrow_mut().record_key_removal();
}
}

Expand Down
85 changes: 28 additions & 57 deletions runtime/runtime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1656,54 +1656,53 @@ impl Runtime {
compute_usage = tracing::field::Empty,
)
.entered();

let state_update = &mut processing_state.state_update;
let node_counter_before = state_update.trie().get_trie_nodes_count();
let recorded_storage_size_before = state_update.trie().recorded_storage_size();
let storage_proof_size_upper_bound_before =
state_update.trie().recorded_storage_size_upper_bound();
let trie = state_update.trie();
let node_counter_before = trie.get_trie_nodes_count();
let recorded_storage_size_before = trie.recorded_storage_size();
let storage_proof_size_upper_bound_before = trie.recorded_storage_size_upper_bound();

// Main logic
let result = self.process_receipt(
processing_state,
receipt,
&mut receipt_sink,
&mut validator_proposals,
);

let total = &mut processing_state.total;
let state_update = &mut processing_state.state_update;
let node_counter_after = state_update.trie().get_trie_nodes_count();
tracing::trace!(target: "runtime", ?node_counter_before, ?node_counter_after);
let recorded_storage_diff = state_update
.trie()
.recorded_storage_size()
.saturating_sub(recorded_storage_size_before)
as f64;
let recorded_storage_upper_bound_diff = state_update
.trie()
.recorded_storage_size_upper_bound()
.saturating_sub(storage_proof_size_upper_bound_before)
as f64;
let shard_id_str = processing_state.apply_state.shard_id.to_string();
let trie = processing_state.state_update.trie();

let node_counter_after = trie.get_trie_nodes_count();
tracing::trace!(target: "runtime", ?node_counter_before, ?node_counter_after);

let recorded_storage_diff = trie.recorded_storage_size() - recorded_storage_size_before;
let recorded_storage_upper_bound_diff =
trie.recorded_storage_size_upper_bound() - storage_proof_size_upper_bound_before;
metrics::RECEIPT_RECORDED_SIZE
.with_label_values(&[shard_id_str.as_str()])
.observe(recorded_storage_diff);
.observe(recorded_storage_diff as f64);
metrics::RECEIPT_RECORDED_SIZE_UPPER_BOUND
.with_label_values(&[shard_id_str.as_str()])
.observe(recorded_storage_upper_bound_diff);
.observe(recorded_storage_upper_bound_diff as f64);
let recorded_storage_proof_ratio =
recorded_storage_upper_bound_diff / f64::max(1.0, recorded_storage_diff);
recorded_storage_upper_bound_diff as f64 / f64::max(1.0, recorded_storage_diff as f64);
// Record the ratio only for large receipts, small receipts can have a very high ratio,
// but the ratio is not that important for them.
if recorded_storage_upper_bound_diff > 100_000. {
if recorded_storage_upper_bound_diff > 100_000 {
metrics::RECEIPT_RECORDED_SIZE_UPPER_BOUND_RATIO
.with_label_values(&[shard_id_str.as_str()])
.observe(recorded_storage_proof_ratio);
}

if let Some(outcome_with_id) = result? {
let gas_burnt = outcome_with_id.outcome.gas_burnt;
let compute_usage = outcome_with_id
.outcome
.compute_usage
.expect("`process_receipt` must populate compute usage");
let total = &mut processing_state.total;
total.add(gas_burnt, compute_usage)?;
span.record("gas_burnt", gas_burnt);
span.record("compute_usage", compute_usage);
Expand All @@ -1726,7 +1725,6 @@ impl Runtime {
mut processing_state: &mut ApplyProcessingReceiptState<'a>,
receipt_sink: &mut ReceiptSink,
compute_limit: u64,
proof_size_limit: Option<usize>,
validator_proposals: &mut Vec<ValidatorStake>,
) -> Result<(), RuntimeError> {
let local_processing_start = std::time::Instant::now();
Expand All @@ -1750,9 +1748,7 @@ impl Runtime {

for receipt in local_receipts.iter() {
if processing_state.total.compute >= compute_limit
|| proof_size_limit.is_some_and(|limit| {
processing_state.state_update.trie.recorded_storage_size_upper_bound() > limit
})
|| processing_state.state_update.trie.check_proof_size_limit_exceed()
{
processing_state.delayed_receipts.push(
&mut processing_state.state_update,
Expand Down Expand Up @@ -1808,7 +1804,6 @@ impl Runtime {
mut processing_state: &mut ApplyProcessingReceiptState<'a>,
receipt_sink: &mut ReceiptSink,
compute_limit: u64,
proof_size_limit: Option<usize>,
validator_proposals: &mut Vec<ValidatorStake>,
) -> Result<Vec<Receipt>, RuntimeError> {
let delayed_processing_start = std::time::Instant::now();
Expand All @@ -1828,9 +1823,7 @@ impl Runtime {

loop {
if processing_state.total.compute >= compute_limit
|| proof_size_limit.is_some_and(|limit| {
processing_state.state_update.trie.recorded_storage_size_upper_bound() > limit
})
|| processing_state.state_update.trie.check_proof_size_limit_exceed()
{
break;
}
Expand Down Expand Up @@ -1910,7 +1903,6 @@ impl Runtime {
mut processing_state: &mut ApplyProcessingReceiptState<'a>,
receipt_sink: &mut ReceiptSink,
compute_limit: u64,
proof_size_limit: Option<usize>,
validator_proposals: &mut Vec<ValidatorStake>,
) -> Result<(), RuntimeError> {
let incoming_processing_start = std::time::Instant::now();
Expand Down Expand Up @@ -1940,9 +1932,7 @@ impl Runtime {
)
.map_err(RuntimeError::ReceiptValidationError)?;
if processing_state.total.compute >= compute_limit
|| proof_size_limit.is_some_and(|limit| {
processing_state.state_update.trie.recorded_storage_size_upper_bound() > limit
})
|| processing_state.state_update.trie.check_proof_size_limit_exceed()
{
processing_state.delayed_receipts.push(
&mut processing_state.state_update,
Expand Down Expand Up @@ -1992,24 +1982,17 @@ impl Runtime {
receipt_sink: &mut ReceiptSink,
) -> Result<ProcessReceiptsResult, RuntimeError> {
let mut validator_proposals = vec![];
let protocol_version = processing_state.protocol_version;
let apply_state = &processing_state.apply_state;

// TODO(#8859): Introduce a dedicated `compute_limit` for the chunk.
// For now compute limit always matches the gas limit.
let compute_limit = apply_state.gas_limit.unwrap_or(Gas::max_value());
let proof_size_limit = if ProtocolFeature::StatelessValidation.enabled(protocol_version) {
Some(apply_state.config.witness_config.main_storage_proof_size_soft_limit)
} else {
None
};

// We first process local receipts. They contain staking, local contract calls, etc.
self.process_local_receipts(
processing_state,
receipt_sink,
compute_limit,
proof_size_limit,
&mut validator_proposals,
)?;

Expand All @@ -2018,7 +2001,6 @@ impl Runtime {
processing_state,
receipt_sink,
compute_limit,
proof_size_limit,
&mut validator_proposals,
)?;

Expand All @@ -2027,26 +2009,19 @@ impl Runtime {
processing_state,
receipt_sink,
compute_limit,
proof_size_limit,
&mut validator_proposals,
)?;

// Resolve timed-out PromiseYield receipts
let promise_yield_result = resolve_promise_yield_timeouts(
processing_state,
receipt_sink,
compute_limit,
proof_size_limit,
)?;
let promise_yield_result =
resolve_promise_yield_timeouts(processing_state, receipt_sink, compute_limit)?;

let shard_id_str = processing_state.apply_state.shard_id.to_string();
if processing_state.total.compute >= compute_limit {
metrics::CHUNK_RECEIPTS_LIMITED_BY
.with_label_values(&[shard_id_str.as_str(), "compute_limit"])
.inc();
} else if proof_size_limit.is_some_and(|limit| {
processing_state.state_update.trie.recorded_storage_size_upper_bound() > limit
}) {
} else if processing_state.state_update.trie.check_proof_size_limit_exceed() {
metrics::CHUNK_RECEIPTS_LIMITED_BY
.with_label_values(&[shard_id_str.as_str(), "storage_proof_size_limit"])
.inc();
Expand Down Expand Up @@ -2351,7 +2326,6 @@ fn resolve_promise_yield_timeouts(
processing_state: &mut ApplyProcessingReceiptState,
receipt_sink: &mut ReceiptSink,
compute_limit: u64,
proof_size_limit: Option<usize>,
) -> Result<ResolvePromiseYieldTimeoutsResult, RuntimeError> {
let mut state_update = &mut processing_state.state_update;
let total = &mut processing_state.total;
Expand All @@ -2366,10 +2340,7 @@ fn resolve_promise_yield_timeouts(
let mut timeout_receipts = vec![];
let yield_processing_start = std::time::Instant::now();
while promise_yield_indices.first_index < promise_yield_indices.next_available_index {
if total.compute >= compute_limit
|| proof_size_limit
.is_some_and(|limit| state_update.trie.recorded_storage_size_upper_bound() > limit)
{
if total.compute >= compute_limit || state_update.trie.check_proof_size_limit_exceed() {
break;
}

Expand Down
Loading

0 comments on commit 6e62db1

Please sign in to comment.