diff --git a/src/core/search/base.cc b/src/core/search/base.cc index b6c4d3f6a611..e4d4121fee00 100644 --- a/src/core/search/base.cc +++ b/src/core/search/base.cc @@ -4,6 +4,8 @@ #include "core/search/base.h" +#include + namespace dfly::search { std::string_view QueryParams::operator[](std::string_view name) const { @@ -37,4 +39,11 @@ WrappedStrPtr::operator std::string_view() const { return std::string_view{ptr.get(), std::strlen(ptr.get())}; } +std::optional ParseNumericField(std::string_view value) { + double value_as_double; + if (absl::SimpleAtod(value, &value_as_double)) + return value_as_double; + return std::nullopt; +} + } // namespace dfly::search diff --git a/src/core/search/base.h b/src/core/search/base.h index 964dcfcee6f2..9ff30472fb0d 100644 --- a/src/core/search/base.h +++ b/src/core/search/base.h @@ -68,11 +68,18 @@ using SortableValue = std::variant; struct DocumentAccessor { using VectorInfo = search::OwnedFtVector; using StringList = absl::InlinedVector; + using NumsList = absl::InlinedVector; virtual ~DocumentAccessor() = default; - virtual StringList GetStrings(std::string_view active_field) const = 0; - virtual VectorInfo GetVector(std::string_view active_field) const = 0; + /* Returns nullopt if the specified field is not a list of strings */ + virtual std::optional GetStrings(std::string_view active_field) const = 0; + + /* Returns nullopt if the specified field is not a vector */ + virtual std::optional GetVector(std::string_view active_field) const = 0; + + /* Return nullopt if the specified field is not a list of doubles */ + virtual std::optional GetNumbers(std::string_view active_field) const = 0; }; // Base class for type-specific indices. @@ -81,8 +88,10 @@ struct DocumentAccessor { // query functions. All results for all index types should be sorted. struct BaseIndex { virtual ~BaseIndex() = default; - virtual void Add(DocId id, DocumentAccessor* doc, std::string_view field) = 0; - virtual void Remove(DocId id, DocumentAccessor* doc, std::string_view field) = 0; + + // Returns true if the document was added / indexed + virtual bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) = 0; + virtual void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) = 0; }; // Base class for type-specific sorting indices. @@ -91,4 +100,20 @@ struct BaseSortIndex : BaseIndex { virtual std::vector Sort(std::vector* ids, size_t limit, bool desc) const = 0; }; +/* Used for converting field values to double. Returns std::nullopt if the conversion fails */ +std::optional ParseNumericField(std::string_view value); + +/* Temporary method to create an empty std::optional in DocumentAccessor::GetString + and DocumentAccessor::GetNumbers methods. The problem is that due to internal implementation + details of absl::InlineVector, we are getting a -Wmaybe-uninitialized compiler warning. To + suppress this false warning, we temporarily disable it around this block of code using GCC + diagnostic directives. */ +template std::optional EmptyAccessResult() { + // GCC 13.1 throws spurious warnings around this code. +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" + return InlinedVector{}; +#pragma GCC diagnostic pop +} + } // namespace dfly::search diff --git a/src/core/search/indices.cc b/src/core/search/indices.cc index 3f7939d8f649..9d125a5a8a13 100644 --- a/src/core/search/indices.cc +++ b/src/core/search/indices.cc @@ -71,19 +71,22 @@ absl::flat_hash_set NormalizeTags(string_view taglist, bool case_sensiti NumericIndex::NumericIndex(PMR_NS::memory_resource* mr) : entries_{mr} { } -void NumericIndex::Add(DocId id, DocumentAccessor* doc, string_view field) { - for (auto str : doc->GetStrings(field)) { - double num; - if (absl::SimpleAtod(str, &num)) - entries_.emplace(num, id); +bool NumericIndex::Add(DocId id, const DocumentAccessor& doc, string_view field) { + auto numbers = doc.GetNumbers(field); + if (!numbers) { + return false; } + + for (auto num : numbers.value()) { + entries_.emplace(num, id); + } + return true; } -void NumericIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) { - for (auto str : doc->GetStrings(field)) { - double num; - if (absl::SimpleAtod(str, &num)) - entries_.erase({num, id}); +void NumericIndex::Remove(DocId id, const DocumentAccessor& doc, string_view field) { + auto numbers = doc.GetNumbers(field).value(); + for (auto num : numbers) { + entries_.erase({num, id}); } } @@ -139,19 +142,27 @@ typename BaseStringIndex::Container* BaseStringIndex::GetOrCreate(string_v } template -void BaseStringIndex::Add(DocId id, DocumentAccessor* doc, string_view field) { +bool BaseStringIndex::Add(DocId id, const DocumentAccessor& doc, string_view field) { + auto strings_list = doc.GetStrings(field); + if (!strings_list) { + return false; + } + absl::flat_hash_set tokens; - for (string_view str : doc->GetStrings(field)) + for (string_view str : strings_list.value()) tokens.merge(Tokenize(str)); for (string_view token : tokens) GetOrCreate(token)->Insert(id); + return true; } template -void BaseStringIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) { +void BaseStringIndex::Remove(DocId id, const DocumentAccessor& doc, string_view field) { + auto strings_list = doc.GetStrings(field).value(); + absl::flat_hash_set tokens; - for (string_view str : doc->GetStrings(field)) + for (string_view str : strings_list) tokens.merge(Tokenize(str)); for (const auto& token : tokens) { @@ -192,6 +203,20 @@ std::pair BaseVectorIndex::Info() const { return {dim_, sim_}; } +bool BaseVectorIndex::Add(DocId id, const DocumentAccessor& doc, std::string_view field) { + auto vector = doc.GetVector(field); + if (!vector) + return false; + + auto& [ptr, size] = vector.value(); + if (ptr && size != dim_) { + return false; + } + + AddVector(id, ptr); + return true; +} + FlatVectorIndex::FlatVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource* mr) : BaseVectorIndex{params.dim, params.sim}, entries_{mr} { @@ -199,19 +224,18 @@ FlatVectorIndex::FlatVectorIndex(const SchemaField::VectorParams& params, entries_.reserve(params.capacity * params.dim); } -void FlatVectorIndex::Add(DocId id, DocumentAccessor* doc, string_view field) { +void FlatVectorIndex::AddVector(DocId id, const VectorPtr& vector) { DCHECK_LE(id * dim_, entries_.size()); if (id * dim_ == entries_.size()) entries_.resize((id + 1) * dim_); // TODO: Let get vector write to buf itself - auto [ptr, size] = doc->GetVector(field); - - if (size == dim_) - memcpy(&entries_[id * dim_], ptr.get(), dim_ * sizeof(float)); + if (vector) { + memcpy(&entries_[id * dim_], vector.get(), dim_ * sizeof(float)); + } } -void FlatVectorIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) { +void FlatVectorIndex::Remove(DocId id, const DocumentAccessor& doc, string_view field) { // noop } @@ -229,7 +253,7 @@ struct HnswlibAdapter { 100 /* seed*/} { } - void Add(float* data, DocId id) { + void Add(const float* data, DocId id) { if (world_.cur_element_count + 1 >= world_.max_elements_) world_.resizeIndex(world_.cur_element_count * 2); world_.addPoint(data, id); @@ -298,10 +322,10 @@ HnswVectorIndex::HnswVectorIndex(const SchemaField::VectorParams& params, PMR_NS HnswVectorIndex::~HnswVectorIndex() { } -void HnswVectorIndex::Add(DocId id, DocumentAccessor* doc, string_view field) { - auto [ptr, size] = doc->GetVector(field); - if (size == dim_) - adapter_->Add(ptr.get(), id); +void HnswVectorIndex::AddVector(DocId id, const VectorPtr& vector) { + if (vector) { + adapter_->Add(vector.get(), id); + } } std::vector> HnswVectorIndex::Knn(float* target, size_t k, @@ -314,7 +338,7 @@ std::vector> HnswVectorIndex::Knn(float* target, size_t return adapter_->Knn(target, k, ef, allowed); } -void HnswVectorIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) { +void HnswVectorIndex::Remove(DocId id, const DocumentAccessor& doc, string_view field) { adapter_->Remove(id); } diff --git a/src/core/search/indices.h b/src/core/search/indices.h index 61c8a6a01853..0058e2043c71 100644 --- a/src/core/search/indices.h +++ b/src/core/search/indices.h @@ -28,8 +28,8 @@ namespace dfly::search { struct NumericIndex : public BaseIndex { explicit NumericIndex(PMR_NS::memory_resource* mr); - void Add(DocId id, DocumentAccessor* doc, std::string_view field) override; - void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override; + bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override; + void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override; std::vector Range(double l, double r) const; @@ -44,8 +44,8 @@ template struct BaseStringIndex : public BaseIndex { BaseStringIndex(PMR_NS::memory_resource* mr, bool case_sensitive); - void Add(DocId id, DocumentAccessor* doc, std::string_view field) override; - void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override; + bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override; + void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override; // Used by Add & Remove to tokenize text value virtual absl::flat_hash_set Tokenize(std::string_view value) const = 0; @@ -53,7 +53,7 @@ template struct BaseStringIndex : public BaseIndex { // Pointer is valid as long as index is not mutated. Nullptr if not found const Container* Matching(std::string_view str) const; - // Iterate over all Machting on prefix. + // Iterate over all Matching on prefix. void MatchingPrefix(std::string_view prefix, absl::FunctionRef cb) const; // Returns all the terms that appear as keys in the reverse index. @@ -97,9 +97,14 @@ struct TagIndex : public BaseStringIndex { struct BaseVectorIndex : public BaseIndex { std::pair Info() const; + bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override final; + protected: BaseVectorIndex(size_t dim, VectorSimilarity sim); + using VectorPtr = decltype(std::declval().first); + virtual void AddVector(DocId id, const VectorPtr& vector) = 0; + size_t dim_; VectorSimilarity sim_; }; @@ -109,11 +114,13 @@ struct BaseVectorIndex : public BaseIndex { struct FlatVectorIndex : public BaseVectorIndex { FlatVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource* mr); - void Add(DocId id, DocumentAccessor* doc, std::string_view field) override; - void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override; + void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override; const float* Get(DocId doc) const; + protected: + void AddVector(DocId id, const VectorPtr& vector) override; + private: PMR_NS::vector entries_; }; @@ -124,13 +131,15 @@ struct HnswVectorIndex : public BaseVectorIndex { HnswVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource* mr); ~HnswVectorIndex(); - void Add(DocId id, DocumentAccessor* doc, std::string_view field) override; - void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override; + void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override; std::vector> Knn(float* target, size_t k, std::optional ef) const; std::vector> Knn(float* target, size_t k, std::optional ef, const std::vector& allowed) const; + protected: + void AddVector(DocId id, const VectorPtr& vector) override; + private: std::unique_ptr adapter_; }; diff --git a/src/core/search/search.cc b/src/core/search/search.cc index cd0cc5a8a232..3faca2c4e1db 100644 --- a/src/core/search/search.cc +++ b/src/core/search/search.cc @@ -571,23 +571,48 @@ void FieldIndices::CreateSortIndices(PMR_NS::memory_resource* mr) { } } -void FieldIndices::Add(DocId doc, DocumentAccessor* access) { - for (auto& [field, index] : indices_) - index->Add(doc, access, field); - for (auto& [field, sort_index] : sort_indices_) - sort_index->Add(doc, access, field); +bool FieldIndices::Add(DocId doc, const DocumentAccessor& access) { + bool was_added = true; + + std::vector> successfully_added_indices; + successfully_added_indices.reserve(indices_.size() + sort_indices_.size()); + + auto try_add = [&](const auto& indices_container) { + for (auto& [field, index] : indices_container) { + if (index->Add(doc, access, field)) { + successfully_added_indices.emplace_back(field, index.get()); + } else { + was_added = false; + break; + } + } + }; + + try_add(indices_); + + if (was_added) { + try_add(sort_indices_); + } + + if (!was_added) { + for (auto& [field, index] : successfully_added_indices) { + index->Remove(doc, access, field); + } + return false; + } all_ids_.insert(upper_bound(all_ids_.begin(), all_ids_.end(), doc), doc); + return true; } -void FieldIndices::Remove(DocId doc, DocumentAccessor* access) { +void FieldIndices::Remove(DocId doc, const DocumentAccessor& access) { for (auto& [field, index] : indices_) index->Remove(doc, access, field); for (auto& [field, sort_index] : sort_indices_) sort_index->Remove(doc, access, field); auto it = lower_bound(all_ids_.begin(), all_ids_.end(), doc); - CHECK(it != all_ids_.end() && *it == doc); + DCHECK(it != all_ids_.end() && *it == doc); all_ids_.erase(it); } diff --git a/src/core/search/search.h b/src/core/search/search.h index c37a67fa7d6e..85798da4779e 100644 --- a/src/core/search/search.h +++ b/src/core/search/search.h @@ -77,8 +77,9 @@ class FieldIndices { // Create indices based on schema and options. Both must outlive the indices FieldIndices(const Schema& schema, const IndicesOptions& options, PMR_NS::memory_resource* mr); - void Add(DocId doc, DocumentAccessor* access); - void Remove(DocId doc, DocumentAccessor* access); + // Returns true if document was added + bool Add(DocId doc, const DocumentAccessor& access); + void Remove(DocId doc, const DocumentAccessor& access); BaseIndex* GetIndex(std::string_view field) const; BaseSortIndex* GetSortIndex(std::string_view field) const; diff --git a/src/core/search/search_test.cc b/src/core/search/search_test.cc index 63573f895198..37752ebdb8f0 100644 --- a/src/core/search/search_test.cc +++ b/src/core/search/search_test.cc @@ -44,13 +44,36 @@ struct MockedDocument : public DocumentAccessor { MockedDocument(std::string test_field) : fields_{{"field", test_field}} { } - StringList GetStrings(string_view field) const override { + std::optional GetStrings(string_view field) const override { auto it = fields_.find(field); - return {it != fields_.end() ? string_view{it->second} : ""}; + if (it == fields_.end()) { + return EmptyAccessResult(); + } + return StringList{string_view{it->second}}; + } + + std::optional GetVector(string_view field) const override { + auto strings_list = GetStrings(field); + if (!strings_list) + return std::nullopt; + return !strings_list->empty() ? BytesToFtVectorSafe(strings_list->front()) : VectorInfo{}; } - VectorInfo GetVector(string_view field) const override { - return BytesToFtVector(GetStrings(field).front()); + std::optional GetNumbers(std::string_view field) const override { + auto strings_list = GetStrings(field); + if (!strings_list) + return std::nullopt; + + NumsList nums_list; + nums_list.reserve(strings_list->size()); + for (auto str : strings_list.value()) { + auto num = ParseNumericField(str); + if (!num) { + return std::nullopt; + } + nums_list.push_back(num.value()); + } + return nums_list; } string DebugFormat() { @@ -121,7 +144,7 @@ class SearchTest : public ::testing::Test { shuffle(entries_.begin(), entries_.end(), default_random_engine{}); for (DocId i = 0; i < entries_.size(); i++) - index.Add(i, &entries_[i].first); + index.Add(i, entries_[i].first); SearchAlgorithm search_algo{}; if (!search_algo.Init(query_, ¶ms_)) { @@ -430,7 +453,7 @@ TEST_F(SearchTest, StopWords) { "explicitly found!"}; for (size_t i = 0; i < documents.size(); i++) { MockedDocument doc{{{"title", documents[i]}}}; - indices.Add(i, &doc); + indices.Add(i, doc); } // words is a stopword @@ -484,7 +507,7 @@ TEST_P(KnnTest, Simple1D) { for (size_t i = 0; i < 100; i++) { Map values{{{"even", i % 2 == 0 ? "YES" : "NO"}, {"pos", ToBytes({float(i)})}}}; MockedDocument doc{values}; - indices.Add(i, &doc); + indices.Add(i, doc); } SearchAlgorithm algo{}; @@ -540,7 +563,7 @@ TEST_P(KnnTest, Simple2D) { for (size_t i = 0; i < ABSL_ARRAYSIZE(kTestCoords); i++) { string coords = ToBytes({kTestCoords[i].first, kTestCoords[i].second}); MockedDocument doc{Map{{"pos", coords}}}; - indices.Add(i, &doc); + indices.Add(i, doc); } SearchAlgorithm algo{}; @@ -602,7 +625,7 @@ TEST_P(KnnTest, Cosine) { for (size_t i = 0; i < ABSL_ARRAYSIZE(kTestCoords); i++) { string coords = ToBytes({kTestCoords[i].first, kTestCoords[i].second}); MockedDocument doc{Map{{"pos", coords}}}; - indices.Add(i, &doc); + indices.Add(i, doc); } SearchAlgorithm algo{}; @@ -646,7 +669,7 @@ TEST_P(KnnTest, AddRemove) { vector documents(10); for (size_t i = 0; i < 10; i++) { documents[i] = Map{{"pos", ToBytes({float(i)})}}; - indices.Add(i, &documents[i]); + indices.Add(i, documents[i]); } SearchAlgorithm algo{}; @@ -661,7 +684,7 @@ TEST_P(KnnTest, AddRemove) { // delete leftmost 5 for (size_t i = 0; i < 5; i++) - indices.Remove(i, &documents[i]); + indices.Remove(i, documents[i]); // search leftmost 5 again { @@ -672,7 +695,7 @@ TEST_P(KnnTest, AddRemove) { // add removed elements for (size_t i = 0; i < 5; i++) - indices.Add(i, &documents[i]); + indices.Add(i, documents[i]); // repeat first search { @@ -693,7 +716,7 @@ TEST_P(KnnTest, AutoResize) { for (size_t i = 0; i < 100; i++) { MockedDocument doc{Map{{"pos", ToBytes({float(i)})}}}; - indices.Add(i, &doc); + indices.Add(i, doc); } EXPECT_EQ(indices.GetAllDocs().size(), 100); @@ -720,7 +743,7 @@ static void BM_VectorSearch(benchmark::State& state) { for (size_t i = 0; i < nvecs; i++) { auto rv = random_vec(); MockedDocument doc{Map{{"pos", ToBytes(rv)}}}; - indices.Add(i, &doc); + indices.Add(i, doc); } SearchAlgorithm algo{}; diff --git a/src/core/search/sort_indices.cc b/src/core/search/sort_indices.cc index ed9e84255f57..2eb2c4aa322d 100644 --- a/src/core/search/sort_indices.cc +++ b/src/core/search/sort_indices.cc @@ -46,15 +46,23 @@ std::vector SimpleValueSortIndex::Sort(std::vector* ids, } template -void SimpleValueSortIndex::Add(DocId id, DocumentAccessor* doc, std::string_view field) { +bool SimpleValueSortIndex::Add(DocId id, const DocumentAccessor& doc, std::string_view field) { + auto field_value = Get(doc, field); + if (!field_value) { + return false; + } + DCHECK_LE(id, values_.size()); // Doc ids grow at most by one if (id >= values_.size()) values_.resize(id + 1); - values_[id] = Get(id, doc, field); + + values_[id] = field_value.value(); + return true; } template -void SimpleValueSortIndex::Remove(DocId id, DocumentAccessor* doc, std::string_view field) { +void SimpleValueSortIndex::Remove(DocId id, const DocumentAccessor& doc, + std::string_view field) { DCHECK_LT(id, values_.size()); values_[id] = T{}; } @@ -66,23 +74,22 @@ template PMR_NS::memory_resource* SimpleValueSortIndex::GetMemRe template struct SimpleValueSortIndex; template struct SimpleValueSortIndex; -double NumericSortIndex::Get(DocId id, DocumentAccessor* doc, std::string_view field) { - auto str = doc->GetStrings(field); - if (str.empty()) - return 0; - - double v; - if (!absl::SimpleAtod(str.front(), &v)) - return 0; - return v; +std::optional NumericSortIndex::Get(const DocumentAccessor& doc, std::string_view field) { + auto numbers_list = doc.GetNumbers(field); + if (!numbers_list) { + return std::nullopt; + } + return !numbers_list->empty() ? numbers_list->front() : 0.0; } -PMR_NS::string StringSortIndex::Get(DocId id, DocumentAccessor* doc, std::string_view field) { - auto str = doc->GetStrings(field); - if (str.empty()) - return ""; - - return PMR_NS::string{str.front(), GetMemRes()}; +std::optional StringSortIndex::Get(const DocumentAccessor& doc, + std::string_view field) { + auto strings_list = doc.GetStrings(field); + if (!strings_list) { + return std::nullopt; + } + return !strings_list->empty() ? PMR_NS::string{strings_list->front(), GetMemRes()} + : PMR_NS::string{GetMemRes()}; } } // namespace dfly::search diff --git a/src/core/search/sort_indices.h b/src/core/search/sort_indices.h index 591839a7738a..bdffc1a0f55d 100644 --- a/src/core/search/sort_indices.h +++ b/src/core/search/sort_indices.h @@ -24,11 +24,11 @@ template struct SimpleValueSortIndex : BaseSortIndex { SortableValue Lookup(DocId doc) const override; std::vector Sort(std::vector* ids, size_t limit, bool desc) const override; - void Add(DocId id, DocumentAccessor* doc, std::string_view field) override; - void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override; + bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override; + void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override; protected: - virtual T Get(DocId id, DocumentAccessor* doc, std::string_view field) = 0; + virtual std::optional Get(const DocumentAccessor& doc, std::string_view field_value) = 0; PMR_NS::memory_resource* GetMemRes() const; @@ -39,14 +39,14 @@ template struct SimpleValueSortIndex : BaseSortIndex { struct NumericSortIndex : public SimpleValueSortIndex { NumericSortIndex(PMR_NS::memory_resource* mr) : SimpleValueSortIndex{mr} {}; - double Get(DocId id, DocumentAccessor* doc, std::string_view field) override; + std::optional Get(const DocumentAccessor& doc, std::string_view field) override; }; // TODO: Map tags to integers for fast sort struct StringSortIndex : public SimpleValueSortIndex { StringSortIndex(PMR_NS::memory_resource* mr) : SimpleValueSortIndex{mr} {}; - PMR_NS::string Get(DocId id, DocumentAccessor* doc, std::string_view field) override; + std::optional Get(const DocumentAccessor& doc, std::string_view field) override; }; } // namespace dfly::search diff --git a/src/core/search/vector_utils.cc b/src/core/search/vector_utils.cc index a9a45911f837..1df311dd8195 100644 --- a/src/core/search/vector_utils.cc +++ b/src/core/search/vector_utils.cc @@ -39,18 +39,28 @@ __attribute__((optimize("fast-math"))) float CosineDistance(const float* u, cons return 0.0f; } -} // namespace - -OwnedFtVector BytesToFtVector(string_view value) { - DCHECK_EQ(value.size() % sizeof(float), 0u) << value.size(); - +OwnedFtVector ConvertToFtVector(string_view value) { // Value cannot be casted directly as it might be not aligned as a float (4 bytes). // Misaligned memory access is UB. size_t size = value.size() / sizeof(float); auto out = make_unique(size); memcpy(out.get(), value.data(), size * sizeof(float)); - return {std::move(out), size}; + return OwnedFtVector{std::move(out), size}; +} + +} // namespace + +OwnedFtVector BytesToFtVector(string_view value) { + DCHECK_EQ(value.size() % sizeof(float), 0u) << value.size(); + return ConvertToFtVector(value); +} + +std::optional BytesToFtVectorSafe(string_view value) { + if (value.size() % sizeof(float)) { + return std::nullopt; + } + return ConvertToFtVector(value); } float VectorDistance(const float* u, const float* v, size_t dims, VectorSimilarity sim) { diff --git a/src/core/search/vector_utils.h b/src/core/search/vector_utils.h index ea19db478ab2..61c95a680891 100644 --- a/src/core/search/vector_utils.h +++ b/src/core/search/vector_utils.h @@ -10,6 +10,10 @@ namespace dfly::search { OwnedFtVector BytesToFtVector(std::string_view value); +// Returns std::nullopt if value can not be converted to the vector +// TODO: Remove unsafe version +std::optional BytesToFtVectorSafe(std::string_view value); + float VectorDistance(const float* u, const float* v, size_t dims, VectorSimilarity sim); } // namespace dfly::search diff --git a/src/server/search/doc_accessors.cc b/src/server/search/doc_accessors.cc index b256647fbf97..cfd7c5e16ce5 100644 --- a/src/server/search/doc_accessors.cc +++ b/src/server/search/doc_accessors.cc @@ -38,43 +38,44 @@ string_view SdsToSafeSv(sds str) { return str != nullptr ? string_view{str, sdslen(str)} : ""sv; } -search::SortableValue FieldToSortableValue(search::SchemaField::FieldType type, string_view value) { +using FieldValue = std::optional; + +FieldValue ToSortableValue(search::SchemaField::FieldType type, string_view value) { + if (value.empty()) { + return std::nullopt; + } + if (type == search::SchemaField::NUMERIC) { - double value_as_double = 0; - if (!absl::SimpleAtod(value, &value_as_double)) { // temporary convert to double + auto value_as_double = search::ParseNumericField(value); + if (!value_as_double) { // temporary convert to double LOG(DFATAL) << "Failed to convert " << value << " to double"; + return std::nullopt; } - return value_as_double; + return value_as_double.value(); } if (type == search::SchemaField::VECTOR) { - auto [ptr, size] = search::BytesToFtVector(value); + auto opt_vector = search::BytesToFtVectorSafe(value); + if (!opt_vector) { + LOG(DFATAL) << "Failed to convert " << value << " to vector"; + return std::nullopt; + } + auto& [ptr, size] = opt_vector.value(); return absl::StrCat("[", absl::StrJoin(absl::Span{ptr.get(), size}, ","), "]"); } return string{value}; } -search::SortableValue JsonToSortableValue(const search::SchemaField::FieldType type, - const JsonType& json) { - if (type == search::SchemaField::NUMERIC) { - return json.as_double(); - } - return json.to_string(); -} - -search::SortableValue ExtractSortableValue(const search::Schema& schema, string_view key, - string_view value) { +FieldValue ExtractSortableValue(const search::Schema& schema, string_view key, string_view value) { auto it = schema.fields.find(key); if (it == schema.fields.end()) - return FieldToSortableValue(search::SchemaField::TEXT, value); - return FieldToSortableValue(it->second.type, value); + return ToSortableValue(search::SchemaField::TEXT, value); + return ToSortableValue(it->second.type, value); } -search::SortableValue ExtractSortableValueFromJson(const search::Schema& schema, string_view key, - const JsonType& json) { - auto it = schema.fields.find(key); - if (it == schema.fields.end()) - return JsonToSortableValue(search::SchemaField::TEXT, json); - return JsonToSortableValue(it->second.type, json); +FieldValue ExtractSortableValueFromJson(const search::Schema& schema, string_view key, + const JsonType& json) { + auto json_as_string = json.to_string(); + return ExtractSortableValue(schema, key, json_as_string); } } // namespace @@ -83,7 +84,11 @@ SearchDocData BaseAccessor::Serialize( const search::Schema& schema, absl::Span> fields) const { SearchDocData out{}; for (const auto& [fident, fname] : fields) { - out[fname] = ExtractSortableValue(schema, fident, absl::StrJoin(GetStrings(fident), ",")); + auto field_value = + ExtractSortableValue(schema, fident, absl::StrJoin(GetStrings(fident).value(), ",")); + if (field_value) { + out[fname] = std::move(field_value).value(); + } } return out; } @@ -92,14 +97,39 @@ SearchDocData BaseAccessor::SerializeDocument(const search::Schema& schema) cons return Serialize(schema); } -BaseAccessor::StringList ListPackAccessor::GetStrings(string_view active_field) const { - auto strsv = container_utils::LpFind(lp_, active_field, intbuf_[0].data()); - return strsv.has_value() ? StringList{*strsv} : StringList{}; +std::optional BaseAccessor::GetVector( + std::string_view active_field) const { + auto strings_list = GetStrings(active_field); + if (strings_list) { + return !strings_list->empty() ? search::BytesToFtVectorSafe(strings_list->front()) + : VectorInfo{}; + } + return std::nullopt; } -BaseAccessor::VectorInfo ListPackAccessor::GetVector(string_view active_field) const { - auto strlist = GetStrings(active_field); - return strlist.empty() ? VectorInfo{} : search::BytesToFtVector(strlist.front()); +std::optional BaseAccessor::GetNumbers( + std::string_view active_field) const { + auto strings_list = GetStrings(active_field); + if (!strings_list) { + return std::nullopt; + } + + NumsList nums_list; + nums_list.reserve(strings_list->size()); + for (auto str : strings_list.value()) { + auto num = search::ParseNumericField(str); + if (!num) { + return std::nullopt; + } + nums_list.push_back(num.value()); + } + return nums_list; +} + +std::optional ListPackAccessor::GetStrings( + string_view active_field) const { + auto strsv = container_utils::LpFind(lp_, active_field, intbuf_[0].data()); + return strsv.has_value() ? StringList{*strsv} : StringList{}; } SearchDocData ListPackAccessor::Serialize(const search::Schema& schema) const { @@ -114,27 +144,29 @@ SearchDocData ListPackAccessor::Serialize(const search::Schema& schema) const { string_view v = container_utils::LpGetView(fptr, intbuf_[1].data()); fptr = lpNext(lp_, fptr); - out[k] = ExtractSortableValue(schema, k, v); + auto field_value = ExtractSortableValue(schema, k, v); + if (field_value) { + out[k] = std::move(field_value).value(); + } } return out; } -BaseAccessor::StringList StringMapAccessor::GetStrings(string_view active_field) const { +std::optional StringMapAccessor::GetStrings( + string_view active_field) const { auto it = hset_->Find(active_field); return it != hset_->end() ? StringList{SdsToSafeSv(it->second)} : StringList{}; } -BaseAccessor::VectorInfo StringMapAccessor::GetVector(string_view active_field) const { - auto strlist = GetStrings(active_field); - return strlist.empty() ? VectorInfo{} : search::BytesToFtVector(strlist.front()); -} - SearchDocData StringMapAccessor::Serialize(const search::Schema& schema) const { SearchDocData out{}; - for (const auto& [kptr, vptr] : *hset_) - out[SdsToSafeSv(kptr)] = ExtractSortableValue(schema, SdsToSafeSv(kptr), SdsToSafeSv(vptr)); - + for (const auto& [kptr, vptr] : *hset_) { + auto field_value = ExtractSortableValue(schema, SdsToSafeSv(kptr), SdsToSafeSv(vptr)); + if (field_value) { + out[SdsToSafeSv(kptr)] = std::move(field_value).value(); + } + } return out; } @@ -159,27 +191,54 @@ struct JsonAccessor::JsonPathContainer { variant> val; }; -BaseAccessor::StringList JsonAccessor::GetStrings(string_view active_field) const { +std::optional JsonAccessor::GetStrings(string_view active_field) const { auto* path = GetPath(active_field); if (!path) - return {}; + return search::EmptyAccessResult(); auto path_res = path->Evaluate(json_); if (path_res.empty()) - return {}; + return search::EmptyAccessResult(); + + if (path_res.size() == 1 && !path_res[0].is_array()) { + if (!path_res[0].is_string()) + return std::nullopt; - if (path_res.size() == 1) { buf_ = path_res[0].as_string(); - return {buf_}; + return StringList{buf_}; } + buf_.clear(); // First, grow buffer and compute string sizes vector sizes; - for (const auto& element : path_res) { + + auto add_json_to_buf = [&](const JsonType& json) { size_t start = buf_.size(); - buf_ += element.as_string(); + buf_ += json.as_string(); sizes.push_back(buf_.size() - start); + }; + + if (!path_res[0].is_array()) { + sizes.reserve(path_res.size()); + for (const auto& element : path_res) { + if (!element.is_string()) + return std::nullopt; + + add_json_to_buf(element); + } + } else { + if (path_res.size() > 1) { + return std::nullopt; + } + + sizes.reserve(path_res[0].size()); + for (const auto& element : path_res[0].array_range()) { + if (!element.is_string()) + return std::nullopt; + + add_json_to_buf(element); + } } // Reposition start pointers to the most recent allocation of buf @@ -194,23 +253,62 @@ BaseAccessor::StringList JsonAccessor::GetStrings(string_view active_field) cons return out; } -BaseAccessor::VectorInfo JsonAccessor::GetVector(string_view active_field) const { +std::optional JsonAccessor::GetVector(string_view active_field) const { auto* path = GetPath(active_field); if (!path) - return {}; + return VectorInfo{}; auto res = path->Evaluate(json_); if (res.empty()) - return {nullptr, 0}; + return VectorInfo{}; + + if (!res[0].is_array()) + return std::nullopt; size_t size = res[0].size(); auto ptr = make_unique(size); size_t i = 0; - for (const auto& v : res[0].array_range()) + for (const auto& v : res[0].array_range()) { + if (!v.is_number()) { + return std::nullopt; + } ptr[i++] = v.as(); + } - return {std::move(ptr), size}; + return BaseAccessor::VectorInfo{std::move(ptr), size}; +} + +std::optional JsonAccessor::GetNumbers(string_view active_field) const { + auto* path = GetPath(active_field); + if (!path) + return search::EmptyAccessResult(); + + auto path_res = path->Evaluate(json_); + if (path_res.empty()) + return search::EmptyAccessResult(); + + NumsList nums_list; + if (!path_res[0].is_array()) { + nums_list.reserve(path_res.size()); + for (const auto& element : path_res) { + if (!element.is_number()) + return std::nullopt; + nums_list.push_back(element.as()); + } + } else { + if (path_res.size() > 1) { + return std::nullopt; + } + + nums_list.reserve(path_res[0].size()); + for (const auto& element : path_res[0].array_range()) { + if (!element.is_number()) + return std::nullopt; + nums_list.push_back(element.as()); + } + } + return nums_list; } JsonAccessor::JsonPathContainer* JsonAccessor::GetPath(std::string_view field) const { @@ -259,8 +357,12 @@ SearchDocData JsonAccessor::Serialize( SearchDocData out{}; for (const auto& [ident, name] : fields) { if (auto* path = GetPath(ident); path) { - if (auto res = path->Evaluate(json_); !res.empty()) - out[name] = ExtractSortableValueFromJson(schema, ident, res[0]); + if (auto res = path->Evaluate(json_); !res.empty()) { + auto field_value = ExtractSortableValueFromJson(schema, ident, res[0]); + if (field_value) { + out[name] = std::move(field_value).value(); + } + } } } return out; diff --git a/src/server/search/doc_accessors.h b/src/server/search/doc_accessors.h index 8a8ab5ae6df6..4a6dea1c0c71 100644 --- a/src/server/search/doc_accessors.h +++ b/src/server/search/doc_accessors.h @@ -12,6 +12,7 @@ #include "core/json/json_object.h" #include "core/search/search.h" +#include "core/search/vector_utils.h" #include "server/common.h" #include "server/search/doc_index.h" #include "server/table.h" @@ -37,6 +38,10 @@ struct BaseAccessor : public search::DocumentAccessor { indexed field */ virtual SearchDocData SerializeDocument(const search::Schema& schema) const; + + // Default implementation uses GetStrings + virtual std::optional GetVector(std::string_view active_field) const; + virtual std::optional GetNumbers(std::string_view active_field) const; }; // Accessor for hashes stored with listpack @@ -46,8 +51,7 @@ struct ListPackAccessor : public BaseAccessor { explicit ListPackAccessor(LpPtr ptr) : lp_{ptr} { } - StringList GetStrings(std::string_view field) const override; - VectorInfo GetVector(std::string_view field) const override; + std::optional GetStrings(std::string_view field) const override; SearchDocData Serialize(const search::Schema& schema) const override; private: @@ -60,8 +64,7 @@ struct StringMapAccessor : public BaseAccessor { explicit StringMapAccessor(StringMap* hset) : hset_{hset} { } - StringList GetStrings(std::string_view field) const override; - VectorInfo GetVector(std::string_view field) const override; + std::optional GetStrings(std::string_view field) const override; SearchDocData Serialize(const search::Schema& schema) const override; private: @@ -75,8 +78,9 @@ struct JsonAccessor : public BaseAccessor { explicit JsonAccessor(const JsonType* json) : json_{*json} { } - StringList GetStrings(std::string_view field) const override; - VectorInfo GetVector(std::string_view field) const override; + std::optional GetStrings(std::string_view field) const override; + std::optional GetVector(std::string_view field) const override; + std::optional GetNumbers(std::string_view active_field) const override; // The JsonAccessor works with structured types and not plain strings, so an overload is needed SearchDocData Serialize(const search::Schema& schema, diff --git a/src/server/search/doc_index.cc b/src/server/search/doc_index.cc index 5835971eb76f..e24acceeda45 100644 --- a/src/server/search/doc_index.cc +++ b/src/server/search/doc_index.cc @@ -41,7 +41,7 @@ void TraverseAllMatching(const DocIndex& index, const OpArgs& op_args, F&& f) { return; auto accessor = GetAccessor(op_args.db_cntx, pv); - f(key, accessor.get()); + f(key, *accessor); }; PrimeTable::Cursor cursor; @@ -146,12 +146,14 @@ ShardDocIndex::DocId ShardDocIndex::DocKeyIndex::Add(string_view key) { return id; } -ShardDocIndex::DocId ShardDocIndex::DocKeyIndex::Remove(string_view key) { - DCHECK_GT(ids_.count(key), 0u); +std::optional ShardDocIndex::DocKeyIndex::Remove(string_view key) { + auto it = ids_.extract(key); + if (!it) { + return std::nullopt; + } - DocId id = ids_.find(key)->second; + const DocId id = it.mapped(); keys_[id] = ""; - ids_.erase(key); free_ids_.push_back(id); return id; @@ -184,7 +186,13 @@ void ShardDocIndex::Rebuild(const OpArgs& op_args, PMR_NS::memory_resource* mr) key_index_ = DocKeyIndex{}; indices_.emplace(base_->schema, base_->options, mr); - auto cb = [this](string_view key, BaseAccessor* doc) { indices_->Add(key_index_.Add(key), doc); }; + auto cb = [this](string_view key, const BaseAccessor& doc) { + DocId id = key_index_.Add(key); + if (!indices_->Add(id, doc)) { + key_index_.Remove(key); + } + }; + TraverseAllMatching(*base_, op_args, cb); VLOG(1) << "Indexed " << key_index_.Size() << " docs on " << base_->prefix; @@ -195,7 +203,10 @@ void ShardDocIndex::AddDoc(string_view key, const DbContext& db_cntx, const Prim return; auto accessor = GetAccessor(db_cntx, pv); - indices_->Add(key_index_.Add(key), accessor.get()); + DocId id = key_index_.Add(key); + if (!indices_->Add(id, *accessor)) { + key_index_.Remove(key); + } } void ShardDocIndex::RemoveDoc(string_view key, const DbContext& db_cntx, const PrimeValue& pv) { @@ -203,8 +214,10 @@ void ShardDocIndex::RemoveDoc(string_view key, const DbContext& db_cntx, const P return; auto accessor = GetAccessor(db_cntx, pv); - DocId id = key_index_.Remove(key); - indices_->Remove(id, accessor.get()); + auto id = key_index_.Remove(key); + if (id) { + indices_->Remove(id.value(), *accessor); + } } bool ShardDocIndex::Matches(string_view key, unsigned obj_code) const { diff --git a/src/server/search/doc_index.h b/src/server/search/doc_index.h index 564ca6193540..1c775f30c245 100644 --- a/src/server/search/doc_index.h +++ b/src/server/search/doc_index.h @@ -133,7 +133,7 @@ class ShardDocIndex { // DocKeyIndex manages mapping document keys to ids and vice versa through a simple interface. struct DocKeyIndex { DocId Add(std::string_view key); - DocId Remove(std::string_view key); + std::optional Remove(std::string_view key); std::string_view Get(DocId id) const; size_t Size() const; diff --git a/src/server/search/search_family_test.cc b/src/server/search/search_family_test.cc index 2fd8406937ea..804a3ed13b9e 100644 --- a/src/server/search/search_family_test.cc +++ b/src/server/search/search_family_test.cc @@ -90,6 +90,13 @@ template auto IsArray(Args... args) { template auto IsUnordArray(Args... args) { return RespArray(UnorderedElementsAre(std::forward(args)...)); } +template +void BuildKvMatchers(std::vector>>& kv_matchers, + const Expected& expected, std::index_sequence) { + std::initializer_list{ + (kv_matchers.emplace_back(Pair(std::get(expected), std::get(expected))), + 0)...}; +} MATCHER_P(IsMapMatcher, expected, "") { if (arg.type != RespExpr::ARRAY) { @@ -97,108 +104,96 @@ MATCHER_P(IsMapMatcher, expected, "") { return false; } + constexpr size_t expected_size = std::tuple_size::value; + constexpr size_t exprected_pairs_number = expected_size / 2; + auto result = arg.GetVec(); - if (result.size() != expected.size()) { + if (result.size() != expected_size) { *result_listener << "Wrong resp array size: " << result.size(); return false; } - using KeyValueArray = std::vector>; - - KeyValueArray received_pairs; + std::vector> received_pairs; for (size_t i = 0; i < result.size(); i += 2) { - received_pairs.emplace_back(result[i].GetString(), result[i + 1].GetString()); - } - - KeyValueArray expected_pairs; - for (size_t i = 0; i < expected.size(); i += 2) { - expected_pairs.emplace_back(expected[i], expected[i + 1]); + received_pairs.emplace_back(result[i].GetString(), result[i + 1]); } - // Custom unordered comparison - std::sort(received_pairs.begin(), received_pairs.end()); - std::sort(expected_pairs.begin(), expected_pairs.end()); + std::vector>> kv_matchers; + BuildKvMatchers(kv_matchers, expected, std::make_index_sequence{}); - return received_pairs == expected_pairs; + return ExplainMatchResult(UnorderedElementsAreArray(kv_matchers), received_pairs, + result_listener); } -template auto IsMap(Matchers... matchers) { - return IsMapMatcher(std::vector{std::forward(matchers)...}); +template auto IsMap(Args... args) { + return IsMapMatcher(std::make_tuple(args...)); } -MATCHER_P(IsUnordArrayWithSizeMatcher, expected, "") { +MATCHER_P(IsMapWithSizeMatcher, expected, "") { if (arg.type != RespExpr::ARRAY) { *result_listener << "Wrong response type: " << arg.type; return false; } + constexpr size_t expected_size = std::tuple_size::value; + constexpr size_t exprected_pairs_number = expected_size / 2; auto result = arg.GetVec(); - size_t expected_size = std::tuple_size::value; - if (result.size() != expected_size + 1) { + if (result.size() != expected_size + 1 || result.size() % 2 != 1) { *result_listener << "Wrong resp array size: " << result.size(); return false; } - if (result[0].GetInt() != expected_size) { - *result_listener << "Wrong elements count: " << result[0].GetInt().value_or(-1); + if (result[0].GetInt() != exprected_pairs_number) { + *result_listener << "Wrong pairs count: " << result[0].GetInt().value_or(-1); return false; } - std::vector received_elements(result.begin() + 1, result.end()); + std::vector> received_pairs; + for (size_t i = 1; i < result.size(); i += 2) { + received_pairs.emplace_back(result[i].GetString(), result[i + 1]); + } - // Create a vector of matchers from the tuple - std::vector> matchers; - std::apply([&matchers](auto&&... args) { ((matchers.push_back(args)), ...); }, expected); + std::vector>> kv_matchers; + BuildKvMatchers(kv_matchers, expected, std::make_index_sequence{}); - return ExplainMatchResult(UnorderedElementsAreArray(matchers), received_elements, + return ExplainMatchResult(UnorderedElementsAreArray(kv_matchers), received_pairs, result_listener); } -template auto IsUnordArrayWithSize(Matchers... matchers) { - return IsUnordArrayWithSizeMatcher(std::make_tuple(matchers...)); -} - -template -void BuildKvMatchers(std::vector>>& kv_matchers, - const Expected& expected, std::index_sequence) { - std::initializer_list{ - (kv_matchers.emplace_back(Pair(std::get(expected), std::get(expected))), - 0)...}; +template auto IsMapWithSize(Args... args) { + return IsMapWithSizeMatcher(std::make_tuple(args...)); } -MATCHER_P(IsMapWithSizeMatcher, expected, "") { +MATCHER_P(IsUnordArrayWithSizeMatcher, expected, "") { if (arg.type != RespExpr::ARRAY) { *result_listener << "Wrong response type: " << arg.type; return false; } - constexpr size_t expected_size = std::tuple_size::value; - constexpr size_t exprected_pairs_number = expected_size / 2; auto result = arg.GetVec(); - if (result.size() != expected_size + 1 || result.size() % 2 != 1) { + size_t expected_size = std::tuple_size::value; + if (result.size() != expected_size + 1) { *result_listener << "Wrong resp array size: " << result.size(); return false; } - if (result[0].GetInt() != exprected_pairs_number) { - *result_listener << "Wrong pairs count: " << result[0].GetInt().value_or(-1); + if (result[0].GetInt() != expected_size) { + *result_listener << "Wrong elements count: " << result[0].GetInt().value_or(-1); return false; } - std::vector> received_pairs; - for (size_t i = 1; i < result.size(); i += 2) { - received_pairs.emplace_back(result[i].GetString(), result[i + 1]); - } + std::vector received_elements(result.begin() + 1, result.end()); - std::vector>> kv_matchers; - BuildKvMatchers(kv_matchers, expected, std::make_index_sequence{}); + // Create a vector of matchers from the tuple + std::vector> matchers; + std::apply([&matchers](auto&&... args) { ((matchers.push_back(args)), ...); }, expected); - return ExplainMatchResult(UnorderedElementsAreArray(kv_matchers), received_pairs, + return ExplainMatchResult(UnorderedElementsAreArray(matchers), received_elements, result_listener); } -template auto IsMapWithSize(Args... args) { - return IsMapWithSizeMatcher(std::make_tuple(args...)); +template auto IsUnordArrayWithSize(Matchers... matchers) { + return IsUnordArrayWithSizeMatcher(std::make_tuple(matchers...)); } TEST_F(SearchFamilyTest, CreateDropListIndex) { @@ -649,7 +644,7 @@ TEST_F(SearchFamilyTest, TestReturn) { // Check non-existing field resp = Run({"ft.search", "i1", "@justA:0", "return", "1", "nothere"}); - EXPECT_THAT(resp, MatchEntry("k0", "nothere", "")); + EXPECT_THAT(resp, MatchEntry("k0")); // Checl implcit __vector_score is provided float score = 20; @@ -1194,8 +1189,8 @@ TEST_F(SearchFamilyTest, AggregateWithLoadOptionHard) { IsMap("foo_total", "10", "word", "item1"))); // Test JSON - Run({"JSON.SET", "j1", ".", R"({"word":"item1","foo":"10","text":"first key"})"}); - Run({"JSON.SET", "j2", ".", R"({"word":"item2","foo":"20","text":"second key"})"}); + Run({"JSON.SET", "j1", ".", R"({"word":"item1","foo":10,"text":"first key"})"}); + Run({"JSON.SET", "j2", ".", R"({"word":"item2","foo":20,"text":"second key"})"}); resp = Run({"FT.CREATE", "i2", "ON", "JSON", "SCHEMA", "$.word", "AS", "word", "TAG", "$.foo", "AS", "foo", "NUMERIC", "$.text", "AS", "text", "TEXT"}); @@ -1214,4 +1209,220 @@ TEST_F(SearchFamilyTest, AggregateWithLoadOptionHard) { } #endif +TEST_F(SearchFamilyTest, WrongFieldTypeJson) { + // Test simple + Run({"JSON.SET", "j1", ".", R"({"value":"one"})"}); + Run({"JSON.SET", "j2", ".", R"({"value":1})"}); + + EXPECT_EQ(Run({"FT.CREATE", "i1", "ON", "JSON", "SCHEMA", "$.value", "AS", "value", "NUMERIC", + "SORTABLE"}), + "OK"); + + auto resp = Run({"FT.SEARCH", "i1", "*"}); + EXPECT_THAT(resp, AreDocIds("j2")); + + resp = Run({"FT.AGGREGATE", "i1", "*", "LOAD", "1", "$.value"}); + EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("$.value", "1"))); + + // Test with two fields. One is loading + Run({"JSON.SET", "j3", ".", R"({"value":"two","another_value":1})"}); + Run({"JSON.SET", "j4", ".", R"({"value":2,"another_value":2})"}); + + EXPECT_EQ(Run({"FT.CREATE", "i2", "ON", "JSON", "SCHEMA", "$.value", "AS", "value", "NUMERIC"}), + "OK"); + + resp = Run({"FT.SEARCH", "i2", "*", "LOAD", "1", "$.another_value"}); + EXPECT_THAT( + resp, IsMapWithSize("j2", IsMap("$", R"({"value":1})"), "j4", + IsMap("$", R"({"another_value":2,"value":2})", "$.another_value", "2"))); + + resp = Run({"FT.AGGREGATE", "i2", "*", "LOAD", "2", "$.value", "$.another_value", "GROUPBY", "2", + "$.value", "$.another_value", "REDUCE", "COUNT", "0", "AS", "count"}); + EXPECT_THAT(resp, + IsUnordArrayWithSize( + IsMap("$.value", "1", "$.another_value", ArgType(RespExpr::NIL), "count", "1"), + IsMap("$.value", "2", "$.another_value", "2", "count", "1"))); + + // Test multiple field values + Run({"JSON.SET", "j5", ".", R"({"arr":[{"id":1},{"id":"two"}]})"}); + Run({"JSON.SET", "j6", ".", R"({"arr":[{"id":1},{"id":2}]})"}); + Run({"JSON.SET", "j7", ".", R"({"arr":[]})"}); + + resp = Run({"FT.CREATE", "i3", "ON", "JSON", "SCHEMA", "$.arr[*].id", "AS", "id", "NUMERIC"}); + EXPECT_EQ(resp, "OK"); + + resp = Run({"FT.SEARCH", "i3", "*"}); + EXPECT_THAT(resp, AreDocIds("j1", "j2", "j3", "j4", "j6", "j7")); // Only j5 fails + + resp = Run({"FT.CREATE", "i4", "ON", "JSON", "SCHEMA", "$.arr[*].id", "AS", "id", "NUMERIC", + "SORTABLE"}); + EXPECT_EQ(resp, "OK"); + + resp = Run({"FT.SEARCH", "i4", "*"}); + EXPECT_THAT(resp, AreDocIds("j1", "j2", "j3", "j4", "j6", "j7")); // Only j5 fails +} + +TEST_F(SearchFamilyTest, WrongFieldTypeHash) { + // Test simple + Run({"HSET", "h1", "value", "one"}); + Run({"HSET", "h2", "value", "1"}); + + EXPECT_EQ(Run({"FT.CREATE", "i1", "ON", "HASH", "SCHEMA", "value", "NUMERIC", "SORTABLE"}), "OK"); + + auto resp = Run({"FT.SEARCH", "i1", "*"}); + EXPECT_THAT(resp, IsMapWithSize("h2", IsMap("value", "1"))); + + resp = Run({"FT.AGGREGATE", "i1", "*", "LOAD", "1", "@value"}); + EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("value", "1"))); + + // Test with two fields. One is loading + Run({"HSET", "h3", "value", "two", "another_value", "1"}); + Run({"HSET", "h4", "value", "2", "another_value", "2"}); + + EXPECT_EQ(Run({"FT.CREATE", "i2", "ON", "HASH", "SCHEMA", "value", "NUMERIC"}), "OK"); + + resp = Run({"FT.SEARCH", "i2", "*", "LOAD", "1", "@another_value"}); + EXPECT_THAT(resp, IsMapWithSize("h2", IsMap("value", "1"), "h4", + IsMap("value", "2", "another_value", "2"))); + + resp = Run({"FT.AGGREGATE", "i2", "*", "LOAD", "2", "@value", "@another_value", "GROUPBY", "2", + "@value", "@another_value", "REDUCE", "COUNT", "0", "AS", "count"}); + EXPECT_THAT(resp, IsUnordArrayWithSize( + IsMap("value", "1", "another_value", ArgType(RespExpr::NIL), "count", "1"), + IsMap("value", "2", "another_value", "2", "count", "1"))); +} + +TEST_F(SearchFamilyTest, WrongFieldTypeHardJson) { + Run({"JSON.SET", "j1", ".", R"({"data":1,"name":"doc_with_int"})"}); + Run({"JSON.SET", "j2", ".", R"({"data":"1","name":"doc_with_int_as_string"})"}); + Run({"JSON.SET", "j3", ".", R"({"data":"string","name":"doc_with_string"})"}); + Run({"JSON.SET", "j4", ".", R"({"name":"no_data"})"}); + Run({"JSON.SET", "j5", ".", R"({"data":[5,4,3],"name":"doc_with_vector"})"}); + Run({"JSON.SET", "j6", ".", R"({"data":"[5,4,3]","name":"doc_with_vector_as_string"})"}); + + auto resp = Run({"FT.CREATE", "i1", "ON", "JSON", "SCHEMA", "$.data", "AS", "data", "NUMERIC"}); + EXPECT_EQ(resp, "OK"); + + resp = Run( + {"FT.CREATE", "i2", "ON", "JSON", "SCHEMA", "$.data", "AS", "data", "NUMERIC", "SORTABLE"}); + EXPECT_EQ(resp, "OK"); + + resp = Run({"FT.CREATE", "i3", "ON", "JSON", "SCHEMA", "$.data", "AS", "data", "TAG"}); + EXPECT_EQ(resp, "OK"); + + resp = + Run({"FT.CREATE", "i4", "ON", "JSON", "SCHEMA", "$.data", "AS", "data", "TAG", "SORTABLE"}); + EXPECT_EQ(resp, "OK"); + + resp = Run({"FT.CREATE", "i5", "ON", "JSON", "SCHEMA", "$.data", "AS", "data", "TEXT"}); + EXPECT_EQ(resp, "OK"); + + resp = + Run({"FT.CREATE", "i6", "ON", "JSON", "SCHEMA", "$.data", "AS", "data", "TEXT", "SORTABLE"}); + EXPECT_EQ(resp, "OK"); + + resp = Run({"FT.CREATE", "i7", "ON", "JSON", "SCHEMA", "$.data", "AS", "data", "VECTOR", "FLAT", + "6", "TYPE", "FLOAT32", "DIM", "3", "DISTANCE_METRIC", "L2"}); + EXPECT_EQ(resp, "OK"); + + resp = Run({"FT.SEARCH", "i1", "*"}); + EXPECT_THAT(resp, AreDocIds("j1", "j4", "j5")); + + resp = Run({"FT.SEARCH", "i2", "*"}); + EXPECT_THAT(resp, AreDocIds("j1", "j4", "j5")); + + resp = Run({"FT.SEARCH", "i3", "*"}); + EXPECT_THAT(resp, AreDocIds("j2", "j3", "j6", "j4")); + + resp = Run({"FT.SEARCH", "i4", "*"}); + EXPECT_THAT(resp, AreDocIds("j2", "j3", "j6", "j4")); + + resp = Run({"FT.SEARCH", "i5", "*"}); + EXPECT_THAT(resp, AreDocIds("j4", "j2", "j3", "j6")); + + resp = Run({"FT.SEARCH", "i6", "*"}); + EXPECT_THAT(resp, AreDocIds("j4", "j2", "j3", "j6")); + + resp = Run({"FT.SEARCH", "i7", "*"}); + EXPECT_THAT(resp, AreDocIds("j4", "j5")); +} + +TEST_F(SearchFamilyTest, WrongFieldTypeHardHash) { + Run({"HSET", "j1", "data", "1", "name", "doc_with_int"}); + Run({"HSET", "j2", "data", "1", "name", "doc_with_int_as_string"}); + Run({"HSET", "j3", "data", "string", "name", "doc_with_string"}); + Run({"HSET", "j4", "name", "no_data"}); + Run({"HSET", "j5", "data", "5,4,3", "name", "doc_with_fake_vector"}); + Run({"HSET", "j6", "data", "[5,4,3]", "name", "doc_with_fake_vector_as_string"}); + + // Vector [1, 2, 3] + std::string vector = std::string("\x3f\x80\x00\x00\x40\x00\x00\x00\x40\x40\x00\x00", 12); + Run({"HSET", "j7", "data", vector, "name", "doc_with_vector [1, 2, 3]"}); + + auto resp = Run({"FT.CREATE", "i1", "ON", "HASH", "SCHEMA", "data", "NUMERIC"}); + EXPECT_EQ(resp, "OK"); + + resp = Run({"FT.CREATE", "i2", "ON", "HASH", "SCHEMA", "data", "NUMERIC", "SORTABLE"}); + EXPECT_EQ(resp, "OK"); + + resp = Run({"FT.CREATE", "i3", "ON", "HASH", "SCHEMA", "data", "TAG"}); + EXPECT_EQ(resp, "OK"); + + resp = Run({"FT.CREATE", "i4", "ON", "HASH", "SCHEMA", "data", "TAG", "SORTABLE"}); + EXPECT_EQ(resp, "OK"); + + resp = Run({"FT.CREATE", "i5", "ON", "HASH", "SCHEMA", "data", "TEXT"}); + EXPECT_EQ(resp, "OK"); + + resp = Run({"FT.CREATE", "i6", "ON", "HASH", "SCHEMA", "data", "TEXT", "SORTABLE"}); + EXPECT_EQ(resp, "OK"); + + resp = Run({"FT.CREATE", "i7", "ON", "HASH", "SCHEMA", "data", "VECTOR", "FLAT", "6", "TYPE", + "FLOAT32", "DIM", "3", "DISTANCE_METRIC", "L2"}); + EXPECT_EQ(resp, "OK"); + + resp = Run({"FT.SEARCH", "i1", "*"}); + EXPECT_THAT(resp, AreDocIds("j2", "j1", "j4")); + + resp = Run({"FT.SEARCH", "i2", "*"}); + EXPECT_THAT(resp, AreDocIds("j2", "j1", "j4")); + + resp = Run({"FT.SEARCH", "i3", "*"}); + EXPECT_THAT(resp, AreDocIds("j2", "j7", "j3", "j6", "j1", "j4", "j5")); + + resp = Run({"FT.SEARCH", "i4", "*"}); + EXPECT_THAT(resp, AreDocIds("j2", "j7", "j3", "j6", "j1", "j4", "j5")); + + resp = Run({"FT.SEARCH", "i5", "*"}); + EXPECT_THAT(resp, AreDocIds("j4", "j2", "j7", "j3", "j6", "j1", "j5")); + + resp = Run({"FT.SEARCH", "i6", "*"}); + EXPECT_THAT(resp, AreDocIds("j4", "j2", "j7", "j3", "j6", "j1", "j5")); + + resp = Run({"FT.SEARCH", "i7", "*"}); + EXPECT_THAT(resp, AreDocIds("j4", "j7")); +} + +TEST_F(SearchFamilyTest, WrongVectorFieldType) { + Run({"JSON.SET", "j1", ".", + R"({"vector_field": [0.1, 0.2, 0.3], "name": "doc_with_correct_dim"})"}); + Run({"JSON.SET", "j2", ".", R"({"vector_field": [0.1, 0.2], "name": "doc_with_small_dim"})"}); + Run({"JSON.SET", "j3", ".", + R"({"vector_field": [0.1, 0.2, 0.3, 0.4], "name": "doc_with_large_dim"})"}); + Run({"JSON.SET", "j4", ".", R"({"vector_field": [1, 2, 3], "name": "doc_with_int_values"})"}); + Run({"JSON.SET", "j5", ".", + R"({"vector_field":"not_vector", "name":"doc_with_incorrect_field_type"})"}); + Run({"JSON.SET", "j6", ".", R"({"name":"doc_with_no_field"})"}); + Run({"JSON.SET", "j7", ".", + R"({"vector_field": [999999999999999999999999999999999999999, -999999999999999999999999999999999999999, 500000000000000000000000000000000000000], "name": "doc_with_out_of_range_values"})"}); + + auto resp = + Run({"FT.CREATE", "index", "ON", "JSON", "SCHEMA", "$.vector_field", "AS", "vector_field", + "VECTOR", "FLAT", "6", "TYPE", "FLOAT32", "DIM", "3", "DISTANCE_METRIC", "L2"}); + EXPECT_EQ(resp, "OK"); + + resp = Run({"FT.SEARCH", "index", "*"}); + EXPECT_THAT(resp, AreDocIds("j6", "j7", "j1", "j4")); +} + } // namespace dfly