Skip to content

Commit

Permalink
fix: nullptr access within reply builder
Browse files Browse the repository at this point in the history
Signed-off-by: kostas <[email protected]>
  • Loading branch information
kostasrim committed Nov 11, 2024
1 parent 79aa5d4 commit 01dc671
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 13 deletions.
30 changes: 27 additions & 3 deletions src/facade/reply_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class SinkReplyBuilder {
public:
constexpr static size_t kMaxInlineSize = 32;
constexpr static size_t kMaxBufferSize = 8192;
enum CollectionType { ARRAY, SET, MAP, PUSH };

explicit SinkReplyBuilder(io::Sink* sink) : sink_(sink) {
}
Expand Down Expand Up @@ -103,6 +104,13 @@ class SinkReplyBuilder {
void SendError(ErrorReply error);
virtual void SendProtocolError(std::string_view str) = 0;

virtual void StartCollection(unsigned len, CollectionType ct) = 0;
virtual void SendValue(std::string_view key, std::string_view value, uint64_t mc_ver,
uint32_t mc_flag) = 0;

virtual void SendNull() = 0;
virtual void SendBulkString(std::string_view str) = 0; // RESP: Blob String

std::string ConsumeLastError() {
return std::exchange(last_error_, {});
}
Expand Down Expand Up @@ -157,12 +165,25 @@ class MCReplyBuilder : public SinkReplyBuilder {
void SendClientError(std::string_view str);
void SendNotFound();

void SendValue(std::string_view key, std::string_view value, uint64_t mc_ver, uint32_t mc_flag);
void SendValue(std::string_view key, std::string_view value, uint64_t mc_ver,
uint32_t mc_flag) final;
void SendSimpleString(std::string_view str) final;
void SendProtocolError(std::string_view str) final;

void SendRaw(std::string_view str);

void StartCollection(unsigned len, CollectionType ct) override {
ABSL_UNREACHABLE();
}

void SendNull() override {
ABSL_UNREACHABLE();
}

void SendBulkString(std::string_view str) override {
ABSL_UNREACHABLE();
}

void SetNoreply(bool noreply) {
noreply_ = noreply;
}
Expand All @@ -178,7 +199,6 @@ class MCReplyBuilder : public SinkReplyBuilder {
// Redis reply builder interface for sending RESP data.
class RedisReplyBuilderBase : public SinkReplyBuilder {
public:
enum CollectionType { ARRAY, SET, MAP, PUSH };
enum VerbatimFormat { TXT, MARKDOWN };

explicit RedisReplyBuilderBase(io::Sink* sink) : SinkReplyBuilder(sink) {
Expand All @@ -199,7 +219,7 @@ class RedisReplyBuilderBase : public SinkReplyBuilder {
virtual void SendDouble(double val); // RESP: Number

virtual void SendNullArray();
virtual void StartCollection(unsigned len, CollectionType ct);
virtual void StartCollection(unsigned len, CollectionType ct) override;

using SinkReplyBuilder::SendError;
void SendError(std::string_view str, std::string_view type = {}) override;
Expand Down Expand Up @@ -235,6 +255,10 @@ class RedisReplyBuilder : public RedisReplyBuilderBase {
void SendSimpleStrArr(const facade::ArgRange& strs);
void SendBulkStrArr(const facade::ArgRange& strs, CollectionType ct = ARRAY);
void SendScoredArray(absl::Span<const std::pair<std::string, double>> arr, bool with_scores);
void SendValue(std::string_view key, std::string_view value, uint64_t mc_ver,
uint32_t mc_flag) override {
ABSL_UNREACHABLE();
}

void SendStored() final;
void SendSetSkipped() final;
Expand Down
14 changes: 14 additions & 0 deletions src/facade/reply_capture.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ void CapturingReplyBuilder::SendBulkString(std::string_view str) {
Capture(BulkString{string{str}});
}

void CapturingReplyBuilder::SendValue(std::string_view key, std::string_view value, uint64_t mc_ver,
uint32_t mc_flag) {
SKIP_LESS(ReplyMode::FULL);
Capture(Value{std::string(key), std::string(value), mc_ver, mc_flag});
}

void CapturingReplyBuilder::StartCollection(unsigned len, CollectionType type) {
SKIP_LESS(ReplyMode::FULL);
stack_.emplace(make_unique<CollectionPayload>(len, type), type == MAP ? len * 2 : len);
Expand Down Expand Up @@ -126,6 +132,14 @@ struct CaptureVisitor {
rb->SendBulkString(bs);
}

void operator()(const CapturingReplyBuilder::Value& bv) {
if (rb->GetProtocol() == Protocol::MEMCACHE) {
rb->SendValue(bv.key, bv.value, bv.mc_ver, bv.mc_flag);
} else {
rb->SendBulkString(bv.value);
}
}

void operator()(CapturingReplyBuilder::Null) {
rb->SendNull();
}
Expand Down
10 changes: 9 additions & 1 deletion src/facade/reply_capture.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ class CapturingReplyBuilder : public RedisReplyBuilder {
void SendDouble(double val) override;
void SendSimpleString(std::string_view str) override;
void SendBulkString(std::string_view str) override;
void SendValue(std::string_view key, std::string_view value, uint64_t mc_ver,
uint32_t mc_flag) override;

void StartCollection(unsigned len, CollectionType type) override;
void SendNullArray() override;
Expand All @@ -42,13 +44,19 @@ class CapturingReplyBuilder : public RedisReplyBuilder {
struct CollectionPayload;
struct SimpleString : public std::string {}; // SendSimpleString
struct BulkString : public std::string {}; // SendBulkString
struct Value {
std::string key;
std::string value;
uint64_t mc_ver = 0; // 0 means we do not output it (i.e has not been requested).
uint32_t mc_flag = 0;
};

CapturingReplyBuilder(ReplyMode mode = ReplyMode::FULL)
: RedisReplyBuilder{nullptr}, reply_mode_{mode}, stack_{}, current_{} {
}

using Payload = std::variant<std::monostate, Null, Error, long, double, SimpleString, BulkString,
std::unique_ptr<CollectionPayload>>;
std::unique_ptr<CollectionPayload>, Value>;

// Non owned Error based on SendError arguments (msg, type)
using ErrorRef = std::pair<std::string_view, std::string_view>;
Expand Down
4 changes: 4 additions & 0 deletions src/server/http_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ struct CaptureVisitor {
absl::StrAppend(&str, "null");
}

void operator()(CapturingReplyBuilder::Value) {
ABSL_UNREACHABLE();
}

void operator()(CapturingReplyBuilder::Error err) {
str = absl::StrCat(R"({"error": ")", err.first, "\"");
}
Expand Down
15 changes: 6 additions & 9 deletions src/server/string_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1313,23 +1313,20 @@ void StringFamily::MGet(CmdArgList args, Transaction* tx, SinkReplyBuilder* buil
}
}

SinkReplyBuilder::ReplyScope scope(builder);
if (builder->GetProtocol() == Protocol::MEMCACHE) {
auto* rb = static_cast<MCReplyBuilder*>(builder);
for (const auto& entry : res) {
if (!entry)
continue;
rb->SendValue(entry->key, entry->value, entry->mc_ver, entry->mc_flag);
builder->SendValue(entry->key, entry->value, entry->mc_ver, entry->mc_flag);
}
rb->SendSimpleString("END");
builder->SendSimpleString("END");
} else {
auto* rb = static_cast<RedisReplyBuilder*>(builder);
rb->StartArray(res.size());
builder->StartCollection(res.size(), SinkReplyBuilder::CollectionType::ARRAY);
for (const auto& entry : res) {
if (entry)
rb->SendBulkString(entry->value);
builder->SendBulkString(entry->value);
else
rb->SendNull();
builder->SendNull();
}
}
}
Expand Down Expand Up @@ -1588,7 +1585,7 @@ void StringFamily::Register(CommandRegistry* registry) {
<< CI{"SUBSTR", CO::READONLY, 4, 1, 1}.HFUNC(GetRange) // Alias for GetRange
<< CI{"SETRANGE", CO::WRITE | CO::DENYOOM, 4, 1, 1}.HFUNC(SetRange)
<< CI{"CL.THROTTLE", CO::WRITE | CO::DENYOOM | CO::FAST, -5, 1, 1, acl::THROTTLE}.HFUNC(
ClThrottle);
ClThrottle);
}

} // namespace dfly

0 comments on commit 01dc671

Please sign in to comment.