Skip to content

Commit

Permalink
Update NGT submodule and add NgtDistance tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lerouxrgd committed Nov 20, 2023
1 parent 4bca544 commit 41184ce
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 1 deletion.
92 changes: 92 additions & 0 deletions src/ngt/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,7 @@ mod tests {
use tempfile::tempdir;

use super::*;
use crate::ngt::NgtDistance;
use crate::EPSILON;

#[test]
Expand Down Expand Up @@ -700,4 +701,95 @@ mod tests {
dir.close()?;
Ok(())
}

fn test_dist(dist: NgtDistance) -> Result<()> {
// Get a temporary directory to store the index
let dir = tempdir()?;
if cfg!(feature = "shared_mem") {
std::fs::remove_dir(dir.path())?;
}

// Create a new index
let prop = NgtProperties::<f32>::dimension(3)?.distance_type(dist)?;
let mut index = NgtIndex::create(&dir.path(), prop)?;

// Insert two vectors and get their id
let vec1 = vec![1.0, 2.0, 3.0];
let vec2 = vec![4.0, 5.0, 6.0];
let id1 = index.insert(vec1)?;
let _id2 = index.insert(vec2)?;

// Build index
index.build(1)?;
index.persist()?;

// Perform a vector search (with 1 result)
let res = index.search(&[1.1, 2.1, 3.1], 1, EPSILON)?;
assert_eq!(res[0].id, id1);

// Checks that vector is removable from the index
index.remove(id1)?;

Ok(())
}

#[test]
fn test_dist_l1() -> Result<()> {
test_dist(NgtDistance::L1)
}

#[test]
fn test_dist_l2() -> Result<()> {
test_dist(NgtDistance::L2)
}

#[test]
fn test_dist_angle() -> Result<()> {
test_dist(NgtDistance::Angle)
}

#[test]
fn test_dist_hamming() -> Result<()> {
test_dist(NgtDistance::Hamming)
}

#[test]
fn test_dist_cosine() -> Result<()> {
test_dist(NgtDistance::Cosine)
}

#[test]
fn test_dist_normalized_angle() -> Result<()> {
test_dist(NgtDistance::NormalizedAngle)
}

#[test]
fn test_dist_normalized_cosine() -> Result<()> {
test_dist(NgtDistance::NormalizedCosine)
}

#[test]
fn test_dist_jaccard() -> Result<()> {
test_dist(NgtDistance::Jaccard)
}

#[test]
fn test_dist_sparse_jaccard() -> Result<()> {
test_dist(NgtDistance::SparseJaccard)
}

#[test]
fn test_dist_normalized_l2() -> Result<()> {
test_dist(NgtDistance::NormalizedL2)
}

#[test]
fn test_dist_poincare() -> Result<()> {
test_dist(NgtDistance::Poincare)
}

#[test]
fn test_dist_lorentz() -> Result<()> {
test_dist(NgtDistance::Lorentz)
}
}

0 comments on commit 41184ce

Please sign in to comment.