Skip to content

Commit

Permalink
fix data race in join auto spill (#8145)
Browse files Browse the repository at this point in the history
close #8144
  • Loading branch information
windtalker authored Sep 26, 2023
1 parent 803b58e commit 7bcd506
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 28 deletions.
4 changes: 2 additions & 2 deletions dbms/src/DataStreams/HashJoinProbeExec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ HashJoinProbeExecPtr HashJoinProbeExec::tryGetRestoreExec()
HashJoinProbeExecPtr HashJoinProbeExec::doTryGetRestoreExec()
{
/// first check if current join has a partition to restore
if (join->isSpilled() && join->hasPartitionSpilledWithLock())
if (join->isSpilled() && join->hasPartitionToRestore())
{
/// get a restore join
if (auto restore_info = join->getOneRestoreStream(max_block_size); restore_info)
Expand All @@ -183,7 +183,7 @@ HashJoinProbeExecPtr HashJoinProbeExec::doTryGetRestoreExec()
restore_probe_exec->setCancellationHook(is_cancelled);
return restore_probe_exec;
}
assert(join->hasPartitionSpilledWithLock() == false);
assert(join->hasPartitionToRestore() == false);
}
return {};
}
Expand Down
9 changes: 9 additions & 0 deletions dbms/src/Interpreters/HashJoinSpillContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,15 @@ void HashJoinSpillContext::finishBuild()
in_build_stage = false;
}

size_t HashJoinSpillContext::spilledPartitionCount()
{
size_t ret = 0;
for (auto & is_spilled : (*partition_is_spilled))
if (is_spilled)
++ret;
return ret;
}

bool HashJoinSpillContext::markPartitionForAutoSpill(size_t partition_id)
{
auto old_value = AutoSpillStatus::NO_NEED_AUTO_SPILL;
Expand Down
1 change: 1 addition & 0 deletions dbms/src/Interpreters/HashJoinSpillContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class HashJoinSpillContext final : public OperatorSpillContext
/// only used in random failpoint
bool markPartitionForAutoSpill(size_t partition_id);
void finishBuild();
size_t spilledPartitionCount();
};

using HashJoinSpillContextPtr = std::shared_ptr<HashJoinSpillContext>;
Expand Down
38 changes: 18 additions & 20 deletions dbms/src/Interpreters/Join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,6 @@ void Join::checkAndMarkPartitionSpilledIfNeededInternal(
/// first spill
hash_join_spill_context->markPartitionSpilled(partition_index);
join_partition.releasePartitionPoolAndHashMap(partition_lock);
spilled_partition_indexes.push_back(partition_index);
}
auto blocks_to_spill = join_partition.trySpillBuildPartition(partition_lock);
markBuildSideSpillData(partition_index, std::move(blocks_to_spill), stream_index);
Expand Down Expand Up @@ -1786,11 +1785,16 @@ void Join::workAfterBuildFinish(size_t stream_index)
auto partition_lock = join_partition->lockPartition();
hash_join_spill_context->markPartitionSpilled(i);
join_partition->releasePartitionPoolAndHashMap(partition_lock);
spilled_partition_indexes.push_back(i);
}
markBuildSideSpillData(i, partitions[i]->trySpillBuildPartition(), stream_index);
}
}

for (size_t i = 0; i < partitions.size(); ++i)
{
if (hash_join_spill_context->isPartitionSpilled(i))
remaining_partition_indexes_to_restore.push_back(i);
}
LOG_DEBUG(log, "memory usage after build finish: {}", getTotalByteCount());

has_build_data_in_memory = std::any_of(partitions.cbegin(), partitions.cend(), [](const auto & p) {
Expand Down Expand Up @@ -1880,11 +1884,11 @@ void Join::workAfterProbeFinish(size_t stream_index)
if (isEnableSpill())
{
// flush cached blocks for spilled partition.
for (auto spilled_partition_index : spilled_partition_indexes)
markProbeSideSpillData(
spilled_partition_index,
partitions[spilled_partition_index]->trySpillProbePartition(),
stream_index);
for (size_t i = 0; i < partitions.size(); ++i)
{
if (hash_join_spill_context->isPartitionSpilled(i))
markProbeSideSpillData(i, partitions[i]->trySpillProbePartition(), stream_index);
}
hash_join_spill_context->finishSpillableStage();
}

Expand Down Expand Up @@ -2149,7 +2153,7 @@ void Join::spillMostMemoryUsedPartitionIfNeed(size_t stream_index)
#ifdef DBMS_PUBLIC_GTEST
// for join spill to disk gtest
if (restore_round == std::max(2, MAX_RESTORE_ROUND_IN_GTEST) - 1
&& spilled_partition_indexes.size() >= partitions.size() / 2)
&& hash_join_spill_context->spilledPartitionCount() >= partitions.size() / 2)
return;
#endif

Expand All @@ -2171,7 +2175,6 @@ void Join::spillMostMemoryUsedPartitionIfNeed(size_t stream_index)
hash_join_spill_context->markPartitionSpilled(partition_to_be_spilled);
partitions[partition_to_be_spilled]->releasePartitionPoolAndHashMap(partition_lock);
auto blocks_to_spill = partitions[partition_to_be_spilled]->trySpillBuildPartition(partition_lock);
spilled_partition_indexes.push_back(partition_to_be_spilled);
markBuildSideSpillData(partition_to_be_spilled, std::move(blocks_to_spill), stream_index);
}
}
Expand All @@ -2183,15 +2186,10 @@ bool Join::getPartitionSpilled(size_t partition_index) const
}


bool Join::hasPartitionSpilledWithLock()
bool Join::hasPartitionToRestore()
{
std::unique_lock lk(build_probe_mutex);
return hasPartitionSpilled();
}

bool Join::hasPartitionSpilled()
{
return !spilled_partition_indexes.empty();
return !remaining_partition_indexes_to_restore.empty();
}

std::optional<RestoreInfo> Join::getOneRestoreStream(size_t max_block_size_)
Expand All @@ -2210,25 +2208,25 @@ std::optional<RestoreInfo> Join::getOneRestoreStream(size_t max_block_size_)
restore_infos.pop_back();
if (restore_infos.empty())
{
spilled_partition_indexes.pop_front();
remaining_partition_indexes_to_restore.pop_front();
}
return restore_info;
}
if (spilled_partition_indexes.empty())
if (remaining_partition_indexes_to_restore.empty())
{
return {};
}

// build new restore infos.
auto spilled_partition_index = spilled_partition_indexes.front();
auto spilled_partition_index = remaining_partition_indexes_to_restore.front();
RUNTIME_CHECK_MSG(
hash_join_spill_context->isPartitionSpilled(spilled_partition_index),
"should not restore unspilled partition.");

if (restore_join_build_concurrency <= 0)
restore_join_build_concurrency = getRestoreJoinBuildConcurrency(
partitions.size(),
spilled_partition_indexes.size(),
remaining_partition_indexes_to_restore.size(),
join_restore_concurrency,
probe_concurrency);
/// for restore join we make sure that the restore_join_build_concurrency is at least 2, so it can be spill again.
Expand Down
6 changes: 2 additions & 4 deletions dbms/src/Interpreters/Join.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,9 +214,7 @@ class Join

bool getPartitionSpilled(size_t partition_index) const;

bool hasPartitionSpilledWithLock();

bool hasPartitionSpilled();
bool hasPartitionToRestore();

bool isSpilled() const { return hash_join_spill_context->isSpilled(); }

Expand Down Expand Up @@ -371,7 +369,7 @@ class Join

JoinPartitions partitions;

std::list<size_t> spilled_partition_indexes;
std::list<size_t> remaining_partition_indexes_to_restore;

Int64 join_restore_concurrency;

Expand Down
4 changes: 2 additions & 2 deletions dbms/src/Operators/HashProbeTransformExec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ HashProbeTransformExecPtr HashProbeTransformExec::tryGetRestoreExec()
return {};

// first check if current join has a partition to restore
if (join->isSpilled() && join->hasPartitionSpilledWithLock())
if (join->isSpilled() && join->hasPartitionToRestore())
{
// get a restore join
if (auto restore_info = join->getOneRestoreStream(max_block_size); restore_info)
Expand Down Expand Up @@ -86,7 +86,7 @@ HashProbeTransformExecPtr HashProbeTransformExec::tryGetRestoreExec()

return restore_probe_exec;
}
assert(join->hasPartitionSpilledWithLock() == false);
assert(join->hasPartitionToRestore() == false);
}

// current join has no more partition to restore, so check if previous join still has partition to restore
Expand Down

0 comments on commit 7bcd506

Please sign in to comment.