forked from GreptimeTeam/greptimedb
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: benchmark some python script (GreptimeTeam#1356)
* test: bench rspy&pyo3 * docs: add TODO * api heavy * Update src/script/benches/py_benchmark.rs Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * style: toml fmt * test: use `rayon` for threadpool * test: compile first, run later --------- Co-authored-by: Ruihang Xia <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
- Loading branch information
Showing
3 changed files
with
228 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,221 @@ | ||
// Copyright 2023 Greptime Team | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
use std::collections::HashMap; | ||
use std::sync::Arc; | ||
|
||
use catalog::local::{MemoryCatalogProvider, MemorySchemaProvider}; | ||
use catalog::{CatalogList, CatalogProvider, SchemaProvider}; | ||
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; | ||
use common_query::Output; | ||
use criterion::{black_box, criterion_group, criterion_main, Criterion}; | ||
use futures::Future; | ||
use once_cell::sync::{Lazy, OnceCell}; | ||
use query::QueryEngineFactory; | ||
use rayon::ThreadPool; | ||
use script::engine::{CompileContext, EvalContext, Script, ScriptEngine}; | ||
use script::python::{PyEngine, PyScript}; | ||
use table::table::numbers::NumbersTable; | ||
use tokio::runtime::Runtime; | ||
|
||
static SCRIPT_ENGINE: Lazy<PyEngine> = Lazy::new(sample_script_engine); | ||
static LOCAL_RUNTIME: OnceCell<tokio::runtime::Runtime> = OnceCell::new(); | ||
fn get_local_runtime() -> std::thread::Result<&'static Runtime> { | ||
let rt = LOCAL_RUNTIME | ||
.get_or_try_init(|| tokio::runtime::Runtime::new().map_err(|e| Box::new(e) as _))?; | ||
Ok(rt) | ||
} | ||
/// a terrible hack to call async from sync by: | ||
/// TODO(discord9): find a better way | ||
/// 1. spawn a new thread | ||
/// 2. create a new runtime in new thread and call `block_on` on it | ||
pub fn block_on_async<T, F>(f: F) -> std::thread::Result<T> | ||
where | ||
F: Future<Output = T> + Send + 'static, | ||
T: Send + 'static, | ||
{ | ||
let rt = get_local_runtime()?; | ||
|
||
std::thread::spawn(move || rt.block_on(f)).join() | ||
} | ||
|
||
pub(crate) fn sample_script_engine() -> PyEngine { | ||
let catalog_list = catalog::local::new_memory_catalog_list().unwrap(); | ||
|
||
let default_schema = Arc::new(MemorySchemaProvider::new()); | ||
default_schema | ||
.register_table("numbers".to_string(), Arc::new(NumbersTable::default())) | ||
.unwrap(); | ||
let default_catalog = Arc::new(MemoryCatalogProvider::new()); | ||
default_catalog | ||
.register_schema(DEFAULT_SCHEMA_NAME.to_string(), default_schema) | ||
.unwrap(); | ||
catalog_list | ||
.register_catalog(DEFAULT_CATALOG_NAME.to_string(), default_catalog) | ||
.unwrap(); | ||
|
||
let factory = QueryEngineFactory::new(catalog_list); | ||
let query_engine = factory.query_engine(); | ||
|
||
PyEngine::new(query_engine.clone()) | ||
} | ||
|
||
async fn compile_script(script: &str) -> PyScript { | ||
SCRIPT_ENGINE | ||
.compile(script, CompileContext::default()) | ||
.await | ||
.unwrap() | ||
} | ||
async fn run_compiled(script: &PyScript) { | ||
let output = script | ||
.execute(HashMap::default(), EvalContext::default()) | ||
.await | ||
.unwrap(); | ||
let _res = match output { | ||
Output::Stream(s) => common_recordbatch::util::collect_batches(s).await.unwrap(), | ||
Output::RecordBatches(rbs) => rbs, | ||
_ => unreachable!(), | ||
}; | ||
} | ||
|
||
/// Compiles a coprocessor script that evaluates a naive recursive | ||
/// `fibonacci({n})` on the given `backend`, then executes the compiled script | ||
/// 10 times. | ||
/// | ||
/// The script is compiled on every call, so compilation is part of whatever | ||
/// the caller times. Panics if compilation or execution fails (see | ||
/// `compile_script` / `run_compiled`). | ||
async fn fibonacci(n: u64, backend: &str) { | ||
let source = format!( | ||
r#" | ||
@copr(returns=["value"], backend="{backend}") | ||
def entry() -> vector[i64]: | ||
def fibonacci(n): | ||
if n <2: | ||
return 1 | ||
else: | ||
return fibonacci(n-1) + fibonacci(n-2) | ||
return fibonacci({n}) | ||
"# | ||
); | ||
let compiled = compile_script(&source).await; | ||
// Run the compiled script repeatedly; each run's output is discarded. | ||
for _ in 0..10 { | ||
run_compiled(&compiled).await; | ||
} | ||
} | ||
|
||
/// Runs the recursive `fibonacci({n})` coprocessor script on every thread of | ||
/// `pool`: each thread compiles the script for `backend` and executes it 10 | ||
/// times, driving the async calls with `block_on` on the runtime returned by | ||
/// `get_local_runtime`. | ||
/// | ||
/// Panics if the shared runtime cannot be obtained, or if compilation or | ||
/// execution fails. | ||
/// | ||
/// TODO(discord9): use a better way to benchmark in parallel | ||
async fn parallel_fibonacci(n: u64, backend: &str, pool: &ThreadPool) { | ||
let source = format!( | ||
r#" | ||
@copr(returns=["value"], backend="{backend}") | ||
def entry() -> vector[i64]: | ||
def fibonacci(n): | ||
if n <2: | ||
return 1 | ||
else: | ||
return fibonacci(n-1) + fibonacci(n-2) | ||
return fibonacci({n}) | ||
"# | ||
); | ||
// Wrapped in an `Arc` so each pool thread can take its own handle to the script text. | ||
let source = Arc::new(source); | ||
// execute the script in parallel for every thread in the pool | ||
pool.broadcast(|_| { | ||
let source = source.clone(); | ||
let rt = get_local_runtime().unwrap(); | ||
// Compile and run synchronously on this pool thread. | ||
rt.block_on(async move { | ||
let compiled = compile_script(&source).await; | ||
for _ in 0..10 { | ||
run_compiled(&compiled).await; | ||
} | ||
}); | ||
}); | ||
} | ||
|
||
/// Compiles a coprocessor script whose `for` loop iterates over | ||
/// `range(1000000)` doing nothing, then executes the compiled script 10 times | ||
/// on the given `backend`. | ||
/// | ||
/// The script is compiled on every call. Panics if compilation or execution | ||
/// fails. | ||
async fn loop_1_million(backend: &str) { | ||
let source = format!( | ||
r#" | ||
@copr(returns=["value"], backend="{backend}") | ||
def entry() -> vector[i64]: | ||
for i in range(1000000): | ||
pass | ||
return 1 | ||
"# | ||
); | ||
let compiled = compile_script(&source).await; | ||
// Run the compiled script repeatedly; each run's output is discarded. | ||
for _ in 0..10 { | ||
run_compiled(&compiled).await; | ||
} | ||
} | ||
|
||
/// Compiles a coprocessor script that is fed the `number` column from | ||
/// `select number from numbers` and, using a `range(1000)` loop, exercises | ||
/// vector addition, multiplication and mask indexing; the compiled script is | ||
/// then executed 10 times on the given `backend`. | ||
/// | ||
/// The script is compiled on every call. Panics if compilation or execution | ||
/// fails. | ||
async fn api_heavy(backend: &str) { | ||
let source = format!( | ||
r#" | ||
from greptime import vector | ||
@copr(args=["number"], sql="select number from numbers", returns=["value"], backend="{backend}") | ||
def entry(number) -> vector[i64]: | ||
for i in range(1000): | ||
n2 = number + number | ||
n_mul = n2 * n2 | ||
n_mask = n_mul[n_mul>2] | ||
return 1 | ||
"# | ||
); | ||
let compiled = compile_script(&source).await; | ||
// Run the compiled script repeatedly; each run's output is discarded. | ||
for _ in 0..10 { | ||
run_compiled(&compiled).await; | ||
} | ||
} | ||
|
||
fn criterion_benchmark(c: &mut Criterion) { | ||
// TODO(discord9): Prime Number, | ||
// and database-local computation/remote download python script comparison | ||
// which require a local mock library | ||
// TODO(discord9): revisit once mock library is ready | ||
|
||
c.bench_function("fib 20 rspy", |b| { | ||
b.to_async(tokio::runtime::Runtime::new().unwrap()) | ||
.iter(|| fibonacci(black_box(20), "rspy")) | ||
}); | ||
c.bench_function("fib 20 pyo3", |b| { | ||
b.to_async(tokio::runtime::Runtime::new().unwrap()) | ||
.iter(|| fibonacci(black_box(20), "pyo3")) | ||
}); | ||
|
||
let pool = rayon::ThreadPoolBuilder::new() | ||
.num_threads(16) | ||
.build() | ||
.unwrap(); | ||
c.bench_function("par fib 20 rspy", |b| { | ||
b.to_async(tokio::runtime::Runtime::new().unwrap()) | ||
.iter(|| parallel_fibonacci(black_box(20), "rspy", &pool)) | ||
}); | ||
c.bench_function("par fib 20 pyo3", |b| { | ||
b.to_async(tokio::runtime::Runtime::new().unwrap()) | ||
.iter(|| parallel_fibonacci(black_box(20), "pyo3", &pool)) | ||
}); | ||
|
||
c.bench_function("loop 1M rspy", |b| { | ||
b.to_async(tokio::runtime::Runtime::new().unwrap()) | ||
.iter(|| loop_1_million(black_box("rspy"))) | ||
}); | ||
c.bench_function("loop 1M pyo3", |b| { | ||
b.to_async(tokio::runtime::Runtime::new().unwrap()) | ||
.iter(|| loop_1_million(black_box("pyo3"))) | ||
}); | ||
c.bench_function("api heavy rspy", |b| { | ||
b.to_async(tokio::runtime::Runtime::new().unwrap()) | ||
.iter(|| api_heavy(black_box("rspy"))) | ||
}); | ||
c.bench_function("api heavy pyo3", |b| { | ||
b.to_async(tokio::runtime::Runtime::new().unwrap()) | ||
.iter(|| api_heavy(black_box("pyo3"))) | ||
}); | ||
} | ||
|
||
// Declare the `benches` group containing `criterion_benchmark`, and hand that | ||
// group to `criterion_main!`. | ||
criterion_group!(benches, criterion_benchmark); | ||
criterion_main!(benches); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters