From 01dc67137c878602790aae41bdea9aad664657c8 Mon Sep 17 00:00:00 2001 From: kostas Date: Mon, 11 Nov 2024 14:32:52 +0200 Subject: [PATCH] fix: nullptr access within reply builder Signed-off-by: kostas --- src/facade/reply_builder.h | 30 +++++++++++++++++++++++++++--- src/facade/reply_capture.cc | 14 ++++++++++++++ src/facade/reply_capture.h | 10 +++++++++- src/server/http_api.cc | 4 ++++ src/server/string_family.cc | 15 ++++++--------- 5 files changed, 60 insertions(+), 13 deletions(-) diff --git a/src/facade/reply_builder.h b/src/facade/reply_builder.h index 762d2835e61f..a510fbcbf7a7 100644 --- a/src/facade/reply_builder.h +++ b/src/facade/reply_builder.h @@ -32,6 +32,7 @@ class SinkReplyBuilder { public: constexpr static size_t kMaxInlineSize = 32; constexpr static size_t kMaxBufferSize = 8192; + enum CollectionType { ARRAY, SET, MAP, PUSH }; explicit SinkReplyBuilder(io::Sink* sink) : sink_(sink) { } @@ -103,6 +104,13 @@ class SinkReplyBuilder { void SendError(ErrorReply error); virtual void SendProtocolError(std::string_view str) = 0; + virtual void StartCollection(unsigned len, CollectionType ct) = 0; + virtual void SendValue(std::string_view key, std::string_view value, uint64_t mc_ver, + uint32_t mc_flag) = 0; + + virtual void SendNull() = 0; + virtual void SendBulkString(std::string_view str) = 0; // RESP: Blob String + std::string ConsumeLastError() { return std::exchange(last_error_, {}); } @@ -157,12 +165,25 @@ class MCReplyBuilder : public SinkReplyBuilder { void SendClientError(std::string_view str); void SendNotFound(); - void SendValue(std::string_view key, std::string_view value, uint64_t mc_ver, uint32_t mc_flag); + void SendValue(std::string_view key, std::string_view value, uint64_t mc_ver, + uint32_t mc_flag) final; void SendSimpleString(std::string_view str) final; void SendProtocolError(std::string_view str) final; void SendRaw(std::string_view str); + void StartCollection(unsigned len, CollectionType ct) override { + ABSL_UNREACHABLE(); + } + + void SendNull() override { + ABSL_UNREACHABLE(); + } + + void SendBulkString(std::string_view str) override { + ABSL_UNREACHABLE(); + } + void SetNoreply(bool noreply) { noreply_ = noreply; } @@ -178,7 +199,6 @@ class MCReplyBuilder : public SinkReplyBuilder { // Redis reply builder interface for sending RESP data. class RedisReplyBuilderBase : public SinkReplyBuilder { public: - enum CollectionType { ARRAY, SET, MAP, PUSH }; enum VerbatimFormat { TXT, MARKDOWN }; explicit RedisReplyBuilderBase(io::Sink* sink) : SinkReplyBuilder(sink) { @@ -199,7 +219,7 @@ class RedisReplyBuilderBase : public SinkReplyBuilder { virtual void SendDouble(double val); // RESP: Number virtual void SendNullArray(); - virtual void StartCollection(unsigned len, CollectionType ct); + virtual void StartCollection(unsigned len, CollectionType ct) override; using SinkReplyBuilder::SendError; void SendError(std::string_view str, std::string_view type = {}) override; @@ -235,6 +255,10 @@ class RedisReplyBuilder : public RedisReplyBuilderBase { void SendSimpleStrArr(const facade::ArgRange& strs); void SendBulkStrArr(const facade::ArgRange& strs, CollectionType ct = ARRAY); void SendScoredArray(absl::Span> arr, bool with_scores); + void SendValue(std::string_view key, std::string_view value, uint64_t mc_ver, + uint32_t mc_flag) override { + ABSL_UNREACHABLE(); + } void SendStored() final; void SendSetSkipped() final; diff --git a/src/facade/reply_capture.cc b/src/facade/reply_capture.cc index 02d00356e885..520a2fb0342b 100644 --- a/src/facade/reply_capture.cc +++ b/src/facade/reply_capture.cc @@ -53,6 +53,12 @@ void CapturingReplyBuilder::SendBulkString(std::string_view str) { Capture(BulkString{string{str}}); } +void CapturingReplyBuilder::SendValue(std::string_view key, std::string_view value, uint64_t mc_ver, + uint32_t mc_flag) { + SKIP_LESS(ReplyMode::FULL); + Capture(Value{std::string(key), std::string(value), mc_ver, mc_flag}); +} + void CapturingReplyBuilder::StartCollection(unsigned len, CollectionType type) { SKIP_LESS(ReplyMode::FULL); stack_.emplace(make_unique(len, type), type == MAP ? len * 2 : len); @@ -126,6 +132,14 @@ struct CaptureVisitor { rb->SendBulkString(bs); } + void operator()(const CapturingReplyBuilder::Value& bv) { + if (rb->GetProtocol() == Protocol::MEMCACHE) { + rb->SendValue(bv.key, bv.value, bv.mc_ver, bv.mc_flag); + } else { + rb->SendBulkString(bv.value); + } + } + void operator()(CapturingReplyBuilder::Null) { rb->SendNull(); } diff --git a/src/facade/reply_capture.h b/src/facade/reply_capture.h index 1ffa4fd7cfce..d6c80df0ecdd 100644 --- a/src/facade/reply_capture.h +++ b/src/facade/reply_capture.h @@ -30,6 +30,8 @@ class CapturingReplyBuilder : public RedisReplyBuilder { void SendDouble(double val) override; void SendSimpleString(std::string_view str) override; void SendBulkString(std::string_view str) override; + void SendValue(std::string_view key, std::string_view value, uint64_t mc_ver, + uint32_t mc_flag) override; void StartCollection(unsigned len, CollectionType type) override; void SendNullArray() override; @@ -42,13 +44,19 @@ class CapturingReplyBuilder : public RedisReplyBuilder { struct CollectionPayload; struct SimpleString : public std::string {}; // SendSimpleString struct BulkString : public std::string {}; // SendBulkString + struct Value { + std::string key; + std::string value; + uint64_t mc_ver = 0; // 0 means we do not output it (i.e has not been requested). + uint32_t mc_flag = 0; + }; CapturingReplyBuilder(ReplyMode mode = ReplyMode::FULL) : RedisReplyBuilder{nullptr}, reply_mode_{mode}, stack_{}, current_{} { } using Payload = std::variant>; + std::unique_ptr, Value>; // Non owned Error based on SendError arguments (msg, type) using ErrorRef = std::pair; diff --git a/src/server/http_api.cc b/src/server/http_api.cc index 8aa4c7b64cee..c3d0800a1dbd 100644 --- a/src/server/http_api.cc +++ b/src/server/http_api.cc @@ -123,6 +123,10 @@ struct CaptureVisitor { absl::StrAppend(&str, "null"); } + void operator()(CapturingReplyBuilder::Value) { + ABSL_UNREACHABLE(); + } + void operator()(CapturingReplyBuilder::Error err) { str = absl::StrCat(R"({"error": ")", err.first, "\""); } diff --git a/src/server/string_family.cc b/src/server/string_family.cc index 22bbb65e1e9a..319eaa44e5b8 100644 --- a/src/server/string_family.cc +++ b/src/server/string_family.cc @@ -1313,23 +1313,20 @@ void StringFamily::MGet(CmdArgList args, Transaction* tx, SinkReplyBuilder* buil } } - SinkReplyBuilder::ReplyScope scope(builder); if (builder->GetProtocol() == Protocol::MEMCACHE) { - auto* rb = static_cast(builder); for (const auto& entry : res) { if (!entry) continue; - rb->SendValue(entry->key, entry->value, entry->mc_ver, entry->mc_flag); + builder->SendValue(entry->key, entry->value, entry->mc_ver, entry->mc_flag); } - rb->SendSimpleString("END"); + builder->SendSimpleString("END"); } else { - auto* rb = static_cast(builder); - rb->StartArray(res.size()); + builder->StartCollection(res.size(), SinkReplyBuilder::CollectionType::ARRAY); for (const auto& entry : res) { if (entry) - rb->SendBulkString(entry->value); + builder->SendBulkString(entry->value); else - rb->SendNull(); + builder->SendNull(); } } } @@ -1588,7 +1585,7 @@ void StringFamily::Register(CommandRegistry* registry) { << CI{"SUBSTR", CO::READONLY, 4, 1, 1}.HFUNC(GetRange) // Alias for GetRange << CI{"SETRANGE", CO::WRITE | CO::DENYOOM, 4, 1, 1}.HFUNC(SetRange) << CI{"CL.THROTTLE", CO::WRITE | CO::DENYOOM | CO::FAST, -5, 1, 1, acl::THROTTLE}.HFUNC( - ClThrottle); + ClThrottle); } } // namespace dfly