From aef425ec0b3a15c28a56c6b6d7a78d3f37d1388c Mon Sep 17 00:00:00 2001 From: baranov Date: Tue, 22 Oct 2024 11:01:52 +0300 Subject: [PATCH 1/5] fix: reconnect when lose master --- tests/advanced/test_replicaset.py | 28 +++++++++++++++++++++- txmongo/collection.py | 12 ++++++---- txmongo/database.py | 5 ++-- txmongo/protocol.py | 40 ++++++++++++++++++++++++------- txmongo/pymongo_errors.py | 34 ++++++++++++++++++++++++++ txmongo/pymongo_internals.py | 35 +-------------------------- 6 files changed, 104 insertions(+), 50 deletions(-) create mode 100644 txmongo/pymongo_errors.py diff --git a/tests/advanced/test_replicaset.py b/tests/advanced/test_replicaset.py index 801c01af..afbf8a17 100644 --- a/tests/advanced/test_replicaset.py +++ b/tests/advanced/test_replicaset.py @@ -17,7 +17,12 @@ from time import time from bson import SON -from pymongo.errors import AutoReconnect, ConfigurationError, OperationFailure +from pymongo.errors import ( + AutoReconnect, + ConfigurationError, + NotPrimaryError, + OperationFailure, +) from twisted.internet import defer, reactor from twisted.trial import unittest @@ -346,3 +351,24 @@ def test_StaleConnection(self): finally: self.__mongod[0].kill(signal.SIGCONT) yield conn.disconnect() + + @defer.inlineCallbacks + def test_close_connection_after_primary_step_down(self): + try: + conn = ConnectionPool(self.master_with_guaranteed_write) + + yield conn.db.coll.insert_one({"x": 42}) + + while True: + try: + yield conn.db.coll.find_one() + yield self.__sleep(1) + yield conn.admin.command( + SON([("replSetStepDown", 86400), ("force", 1)]) + ) + except NotPrimaryError: + break # this is what we should have returned + + finally: + yield conn.disconnect() + self.flushLoggedErrors(NotPrimaryError) diff --git a/txmongo/collection.py b/txmongo/collection.py index 772e9270..3b8abeec 100644 --- a/txmongo/collection.py +++ b/txmongo/collection.py @@ -524,15 +524,15 @@ def after_connection(proto): flags, ) - return proto.send_simple_msg(cmd, codec_options).addCallback( - after_reply, after_reply, proto - ) + return proto.send_simple_msg( + cmd, codec_options, flag_bits=Msg.create_flag_bits(False) + ).addCallback(after_reply, after_reply, proto) # this_func argument is just a reference to after_reply function itself. # after_reply can reference to itself directly but this will create a circular # reference between closure and function object which will add unnecessary # work for GC. - def after_reply(reply, this_func, proto, fetched=0): + def after_reply(reply: dict, this_func, proto, fetched=0): try: check_deadline(_deadline) except Exception: @@ -586,7 +586,9 @@ def after_reply(reply, this_func, proto, fetched=0): if batch_size: get_more["batchSize"] = batch_size - next_reply = proto.send_simple_msg(get_more, codec_options) + next_reply = proto.send_simple_msg( + get_more, codec_options, flag_bits=Msg.create_flag_bits(False) + ) next_reply.addCallback(this_func, this_func, proto, fetched) return out, next_reply diff --git a/txmongo/database.py b/txmongo/database.py index 7a4910b4..ea0963a4 100644 --- a/txmongo/database.py +++ b/txmongo/database.py @@ -5,7 +5,6 @@ from twisted.internet import defer from txmongo.collection import Collection -from txmongo.protocol import Msg from txmongo.pymongo_internals import _check_command_response from txmongo.utils import check_deadline, timeout @@ -78,7 +77,9 @@ def command( proto = yield self.connection.getprotocol() check_deadline(_deadline) - reply = yield proto.send_simple_msg(command, codec_options) + reply = yield proto.send_simple_msg( + command, codec_options, flag_bits=0, check=False + ) if check: msg = "TxMongo: command {0} on namespace {1} failed with '%s'".format( repr(command), self diff --git a/txmongo/protocol.py b/txmongo/protocol.py index 75f99343..3794f3a9 100644 --- a/txmongo/protocol.py +++ b/txmongo/protocol.py @@ -38,6 +38,8 @@ from twisted.internet import defer, error, protocol from twisted.python import failure, log +from txmongo.pymongo_errors import _NOT_MASTER_CODES + try: from pymongo.synchronous import auth except ImportError: @@ -527,18 +529,35 @@ def send_msg(self, msg: Msg) -> defer.Deferred[Msg]: return defer.succeed(None) return self.__wait_for_reply_to(request_id) + def _check_master(self, answer: dict): + if answer.get("ok") == 0: + if answer.get("code", -1) in _NOT_MASTER_CODES: + self.transport.loseConnection() + return NotPrimaryError( + "TxMongo: " + answer.get("errmsg", "Unknown error") + ) + def send_simple_msg( - self, body: dict, codec_options: CodecOptions + self, body: dict, codec_options: CodecOptions, flag_bits: int, check=True ) -> defer.Deferred[dict]: """Send simple OP_MSG without extracted payload and return parsed response.""" def on_response(response: Msg): reply = bson.decode(response.body, codec_options) for key, bin_docs in msg.payload.items(): - reply[key] = [bson.decode(doc, codec_options) for doc in bin_docs] + reply[key] = [] + for doc in bin_docs: + answer = bson.decode(doc, codec_options) + if check: + if master_error := self._check_master(answer): + reply[key].append(master_error) + break + reply[key].append(answer) return reply - msg = Msg(body=bson.encode(body, codec_options=codec_options)) + msg = Msg( + flag_bits=flag_bits, body=bson.encode(body, codec_options=codec_options) + ) return self.send_msg(msg).addCallback(on_response) def handle(self, request: BaseMessage): @@ -552,7 +571,7 @@ def handle(self, request: BaseMessage): logLevel=logging.WARNING, ) - def handle_reply(self, request): + def handle_reply(self, request: Reply): if request.response_to in self.__deferreds: df = self.__deferreds.pop(request.response_to) if request.response_flags & REPLY_QUERY_FAILURE: @@ -560,7 +579,8 @@ def handle_reply(self, request): code = doc.get("code") msg = "TxMongo: " + doc.get("$err", "Unknown error") fail_conn = False - if code == 13435: + + if code in _NOT_MASTER_CODES: err = NotPrimaryError(msg) fail_conn = True else: @@ -576,9 +596,13 @@ def handle_reply(self, request): else: df.callback(request) - def handle_msg(self, request: Msg): - if dfr := self.__deferreds.pop(request.response_to, None): - dfr.callback(request) + def handle_msg(self, msg: Msg): + if dfr := self.__deferreds.pop(msg.response_to, None): + answer = bson.decode(msg.body) + if master_error := self._check_master(answer): + dfr.errback(master_error) + else: + dfr.callback(msg) def set_wire_versions(self, min_wire_version, max_wire_version): self.min_wire_version = min_wire_version diff --git a/txmongo/pymongo_errors.py b/txmongo/pymongo_errors.py new file mode 100644 index 00000000..119634d6 --- /dev/null +++ b/txmongo/pymongo_errors.py @@ -0,0 +1,34 @@ +# Copied from pymongo/helpers.py:32 at commit d7d94b2776098dba32686ddf3ada1f201172daaf + +# From the SDAM spec, the "node is shutting down" codes. +_SHUTDOWN_CODES = frozenset( + [ + 11600, # InterruptedAtShutdown + 91, # ShutdownInProgress + ] +) +# From the SDAM spec, the "not master" error codes are combined with the +# "node is recovering" error codes (of which the "node is shutting down" +# errors are a subset). +_NOT_MASTER_CODES = ( + frozenset( + [ + 10058, # LegacyNotPrimary <=3.2 "not primary" error code + 10107, # NotMaster + 13435, # NotMasterNoSlaveOk + 11602, # InterruptedDueToReplStateChange + 13436, # NotMasterOrSecondary + 189, # PrimarySteppedDown + ] + ) + | _SHUTDOWN_CODES +) +# From the retryable writes spec. +_RETRYABLE_ERROR_CODES = _NOT_MASTER_CODES | frozenset( + [ + 7, # HostNotFound + 6, # HostUnreachable + 89, # NetworkTimeout + 9001, # SocketException + ] +) diff --git a/txmongo/pymongo_internals.py b/txmongo/pymongo_internals.py index 4fa3c646..0610e008 100644 --- a/txmongo/pymongo_internals.py +++ b/txmongo/pymongo_internals.py @@ -12,40 +12,7 @@ ) from txmongo._bulk import _DELETE, _INSERT, _UPDATE, _Run - -# Copied from pymongo/helpers.py:32 at commit d7d94b2776098dba32686ddf3ada1f201172daaf - -# From the SDAM spec, the "node is shutting down" codes. -_SHUTDOWN_CODES = frozenset( - [ - 11600, # InterruptedAtShutdown - 91, # ShutdownInProgress - ] -) -# From the SDAM spec, the "not master" error codes are combined with the -# "node is recovering" error codes (of which the "node is shutting down" -# errors are a subset). -_NOT_MASTER_CODES = ( - frozenset( - [ - 10107, # NotMaster - 13435, # NotMasterNoSlaveOk - 11602, # InterruptedDueToReplStateChange - 13436, # NotMasterOrSecondary - 189, # PrimarySteppedDown - ] - ) - | _SHUTDOWN_CODES -) -# From the retryable writes spec. -_RETRYABLE_ERROR_CODES = _NOT_MASTER_CODES | frozenset( - [ - 7, # HostNotFound - 6, # HostUnreachable - 89, # NetworkTimeout - 9001, # SocketException - ] -) +from txmongo.pymongo_errors import _NOT_MASTER_CODES # Copied from pymongo/helpers.py:193 at commit 47b0d8ebfd6cefca80c1e4521b47aec7cf8f529d From 5cd4f28375e213270180d4831d3ed23e0aefc734 Mon Sep 17 00:00:00 2001 From: baranov Date: Tue, 22 Oct 2024 11:27:56 +0300 Subject: [PATCH 2/5] revert: send simple msg --- txmongo/collection.py | 10 ++++------ txmongo/database.py | 4 +--- txmongo/protocol.py | 22 +++++++++++++--------- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/txmongo/collection.py b/txmongo/collection.py index 3b8abeec..565e3770 100644 --- a/txmongo/collection.py +++ b/txmongo/collection.py @@ -524,9 +524,9 @@ def after_connection(proto): flags, ) - return proto.send_simple_msg( - cmd, codec_options, flag_bits=Msg.create_flag_bits(False) - ).addCallback(after_reply, after_reply, proto) + return proto.send_simple_msg(cmd, codec_options).addCallback( + after_reply, after_reply, proto + ) # this_func argument is just a reference to after_reply function itself. # after_reply can reference to itself directly but this will create a circular @@ -586,9 +586,7 @@ def after_reply(reply: dict, this_func, proto, fetched=0): if batch_size: get_more["batchSize"] = batch_size - next_reply = proto.send_simple_msg( - get_more, codec_options, flag_bits=Msg.create_flag_bits(False) - ) + next_reply = proto.send_simple_msg(get_more, codec_options) next_reply.addCallback(this_func, this_func, proto, fetched) return out, next_reply diff --git a/txmongo/database.py b/txmongo/database.py index ea0963a4..5ce9e301 100644 --- a/txmongo/database.py +++ b/txmongo/database.py @@ -77,9 +77,7 @@ def command( proto = yield self.connection.getprotocol() check_deadline(_deadline) - reply = yield proto.send_simple_msg( - command, codec_options, flag_bits=0, check=False - ) + reply = yield proto.send_simple_msg(command, codec_options) if check: msg = "TxMongo: command {0} on namespace {1} failed with '%s'".format( repr(command), self diff --git a/txmongo/protocol.py b/txmongo/protocol.py index 3794f3a9..0bd473a7 100644 --- a/txmongo/protocol.py +++ b/txmongo/protocol.py @@ -538,7 +538,7 @@ def _check_master(self, answer: dict): ) def send_simple_msg( - self, body: dict, codec_options: CodecOptions, flag_bits: int, check=True + self, body: dict, codec_options: CodecOptions ) -> defer.Deferred[dict]: """Send simple OP_MSG without extracted payload and return parsed response.""" @@ -548,17 +548,21 @@ def on_response(response: Msg): reply[key] = [] for doc in bin_docs: answer = bson.decode(doc, codec_options) - if check: - if master_error := self._check_master(answer): - reply[key].append(master_error) - break + # if check: + # if master_error := self._check_master(answer): + # reply[key].append(master_error) + # break reply[key].append(answer) return reply - msg = Msg( - flag_bits=flag_bits, body=bson.encode(body, codec_options=codec_options) - ) - return self.send_msg(msg).addCallback(on_response) + def on_response_v1(response: Msg): + reply = bson.decode(response.body, codec_options) + for key, bin_docs in msg.payload.items(): + reply[key] = [bson.decode(doc, codec_options) for doc in bin_docs] + return reply + + msg = Msg(body=bson.encode(body, codec_options=codec_options)) + return self.send_msg(msg).addCallback(on_response_v1) def handle(self, request: BaseMessage): if isinstance(request, Reply): From a022b7827176ad8393add0aa77ba0b4292de17b3 Mon Sep 17 00:00:00 2001 From: Ilya Skriblovsky Date: Thu, 24 Oct 2024 08:19:09 +0300 Subject: [PATCH 3/5] Unify send_msg and send_simple_msg --- tests/advanced/test_replicaset.py | 25 +++-- tests/basic/test_bulk.py | 18 ++-- tests/basic/test_protocol.py | 20 ++-- tests/basic/test_queries.py | 5 +- tests/mongod.py | 2 +- txmongo/_bulk.py | 23 ++--- txmongo/_bulk_constants.py | 13 +++ txmongo/collection.py | 157 ++++++++++++------------------ txmongo/database.py | 17 ++-- txmongo/protocol.py | 118 +++++++++++++--------- txmongo/pymongo_internals.py | 9 +- 11 files changed, 203 insertions(+), 204 deletions(-) create mode 100644 txmongo/_bulk_constants.py diff --git a/tests/advanced/test_replicaset.py b/tests/advanced/test_replicaset.py index afbf8a17..516fdea7 100644 --- a/tests/advanced/test_replicaset.py +++ b/tests/advanced/test_replicaset.py @@ -17,12 +17,7 @@ from time import time from bson import SON -from pymongo.errors import ( - AutoReconnect, - ConfigurationError, - NotPrimaryError, - OperationFailure, -) +from pymongo.errors import AutoReconnect, ConfigurationError, NotPrimaryError from twisted.internet import defer, reactor from twisted.trial import unittest @@ -353,21 +348,23 @@ def test_StaleConnection(self): yield conn.disconnect() @defer.inlineCallbacks - def test_close_connection_after_primary_step_down(self): + def test_CloseConnectionAfterPrimaryStepDown(self): + conn = ConnectionPool(self.master_with_guaranteed_write) try: - conn = ConnectionPool(self.master_with_guaranteed_write) - yield conn.db.coll.insert_one({"x": 42}) + got_not_primary_error = False + while True: try: yield conn.db.coll.find_one() + if got_not_primary_error: + # We got error and then restored — OK + break yield self.__sleep(1) - yield conn.admin.command( - SON([("replSetStepDown", 86400), ("force", 1)]) - ) - except NotPrimaryError: - break # this is what we should have returned + yield conn.admin.command({"replSetStepDown": 86400, "force": 1}) + except (NotPrimaryError, AutoReconnect): + got_not_primary_error = True finally: yield conn.disconnect() diff --git a/tests/basic/test_bulk.py b/tests/basic/test_bulk.py index 05db828a..d38e6071 100644 --- a/tests/basic/test_bulk.py +++ b/tests/basic/test_bulk.py @@ -269,20 +269,18 @@ def test_OperationFailure(self): def fake_send_query(*args): return defer.succeed( - Msg( - body=bson.encode( - { - "ok": 0.0, - "errmsg": "operation was interrupted", - "code": 11602, - "codeName": "InterruptedDueToReplStateChange", - } - ) + Msg.create( + { + "ok": 0.0, + "errmsg": "operation was interrupted", + "code": 11602, + "codeName": "InterruptedDueToReplStateChange", + } ) ) with patch( - "txmongo.protocol.MongoProtocol.send_msg", side_effect=fake_send_query + "txmongo.protocol.MongoProtocol._send_raw_msg", side_effect=fake_send_query ): yield self.assertFailure( self.coll.bulk_write( diff --git a/tests/basic/test_protocol.py b/tests/basic/test_protocol.py index 0a5d0180..9947f771 100644 --- a/tests/basic/test_protocol.py +++ b/tests/basic/test_protocol.py @@ -87,24 +87,24 @@ def test_EncodeDecodeReply(self): self.assertEqual(decoded.documents, request.documents) def test_EncodeDecodeMsg(self): - request = Msg( - response_to=123, - flag_bits=OP_MSG_MORE_TO_COME, - body=bson.encode({"a": 1, "$db": "dbname"}), + request = Msg.create( + body={"a": 1, "$db": "dbname"}, payload={ "documents": [ - bson.encode({"a": 1}), - bson.encode({"a": 2}), + {"a": 1}, + {"a": 2}, ], "updates": [ - bson.encode({"$set": {"z": 1}}), - bson.encode({"$set": {"z": 2}}), + {"$set": {"z": 1}}, + {"$set": {"z": 2}}, ], "deletes": [ - bson.encode({"_id": ObjectId()}), - bson.encode({"_id": ObjectId()}), + {"_id": ObjectId()}, + {"_id": ObjectId()}, ], }, + acknowledged=False, + response_to=123, ) decoded = self._encode_decode(request) diff --git a/tests/basic/test_queries.py b/tests/basic/test_queries.py index f48083c5..52ba4655 100644 --- a/tests/basic/test_queries.py +++ b/tests/basic/test_queries.py @@ -203,10 +203,7 @@ def test_CursorClosingWithTimeout(self): {"$where": "sleep(100); true"}, batch_size=5, timeout=0.8 ) with patch.object( - MongoProtocol, - "send_msg", - side_effect=MongoProtocol.send_msg, - autospec=True, + MongoProtocol, "send_msg", side_effect=MongoProtocol.send_msg, autospec=True ) as mock: with self.assertRaises(TimeExceeded): yield dfr diff --git a/tests/mongod.py b/tests/mongod.py index b455c88a..f60d03ac 100644 --- a/tests/mongod.py +++ b/tests/mongod.py @@ -96,7 +96,7 @@ def stop(self): if self._proc and self._proc.pid: d = defer.Deferred() self._notify_stop.append(d) - self._proc.signalProcess("INT") + self.kill("INT") return d else: return defer.fail("Not started yet") diff --git a/txmongo/_bulk.py b/txmongo/_bulk.py index c0e1a975..a63b84d9 100644 --- a/txmongo/_bulk.py +++ b/txmongo/_bulk.py @@ -18,27 +18,18 @@ validate_ok_for_update, ) +from txmongo._bulk_constants import ( + _DELETE, + _INSERT, + _UPDATE, + COMMAND_NAME, + PAYLOAD_ARG_NAME, +) from txmongo.protocol import MongoProtocol, Msg from txmongo.types import Document _WriteOp = Union[InsertOne, UpdateOne, UpdateMany, ReplaceOne, DeleteOne, DeleteMany] -_INSERT = 0 -_UPDATE = 1 -_DELETE = 2 - -COMMAND_NAME = { - _INSERT: "insert", - _UPDATE: "update", - _DELETE: "delete", -} - -PAYLOAD_ARG_NAME = { - _INSERT: "documents", - _UPDATE: "updates", - _DELETE: "deletes", -} - class _Run: op_type: int diff --git a/txmongo/_bulk_constants.py b/txmongo/_bulk_constants.py new file mode 100644 index 00000000..1d95483d --- /dev/null +++ b/txmongo/_bulk_constants.py @@ -0,0 +1,13 @@ +_INSERT = 0 +_UPDATE = 1 +_DELETE = 2 +COMMAND_NAME = { + _INSERT: "insert", + _UPDATE: "update", + _DELETE: "delete", +} +PAYLOAD_ARG_NAME = { + _INSERT: "documents", + _UPDATE: "updates", + _DELETE: "deletes", +} diff --git a/txmongo/collection.py b/txmongo/collection.py index 565e3770..7d1aa7f8 100644 --- a/txmongo/collection.py +++ b/txmongo/collection.py @@ -6,7 +6,6 @@ from operator import itemgetter from typing import Iterable, List, Optional -import bson from bson import ObjectId from bson.codec_options import CodecOptions from bson.son import SON @@ -32,19 +31,10 @@ from twisted.python.compat import comparable from txmongo import filter as qf -from txmongo._bulk import _INSERT, _Bulk, _Run -from txmongo.protocol import ( - OP_MSG_MORE_TO_COME, - QUERY_PARTIAL, - QUERY_SLAVE_OK, - MongoProtocol, - Msg, -) -from txmongo.pymongo_internals import ( - _check_command_response, - _check_write_command_response, - _merge_command, -) +from txmongo._bulk import _Bulk, _Run +from txmongo._bulk_constants import _INSERT +from txmongo.protocol import QUERY_PARTIAL, QUERY_SLAVE_OK, MongoProtocol, Msg +from txmongo.pymongo_internals import _check_write_command_response, _merge_command from txmongo.types import Document from txmongo.utils import check_deadline, timeout @@ -413,9 +403,8 @@ def query(): "$showDiskLoc": "showRecordId", # <= MongoDB 3.0 } - @classmethod def _gen_find_command( - cls, + self, db_name: str, coll_name: str, filter_with_modifiers, @@ -425,12 +414,16 @@ def _gen_find_command( batch_size, allow_partial_results, flags: int, - ): + ) -> Msg: cmd = {"find": coll_name} if "$query" in filter_with_modifiers: cmd.update( [ - (cls._MODIFIERS[key], val) if key in cls._MODIFIERS else (key, val) + ( + (self._MODIFIERS[key], val) + if key in self._MODIFIERS + else (key, val) + ) for key, val in filter_with_modifiers.items() ] ) @@ -459,19 +452,17 @@ def _gen_find_command( cmd = {"explain": cmd} cmd["$db"] = db_name - return cmd + return Msg.create(cmd, codec_options=self.codec_options) def __close_cursor_without_response(self, proto: MongoProtocol, cursor_id: int): proto.send_msg( - Msg( - flag_bits=OP_MSG_MORE_TO_COME, - body=bson.encode( - { - "killCursors": self.name, - "$db": self._database.name, - "cursors": [cursor_id], - }, - ), + Msg.create( + { + "killCursors": self.name, + "$db": self._database.name, + "cursors": [cursor_id], + }, + acknowledged=False, ) ) @@ -524,7 +515,7 @@ def after_connection(proto): flags, ) - return proto.send_simple_msg(cmd, codec_options).addCallback( + return proto.send_msg(cmd, codec_options).addCallback( after_reply, after_reply, proto ) @@ -541,8 +532,6 @@ def after_reply(reply: dict, this_func, proto, fetched=0): self.__close_cursor_without_response(proto, cursor_id) raise - _check_command_response(reply) - if "cursor" not in reply: # For example, when we run `explain` command return [reply], defer.succeed(([], None)) @@ -586,7 +575,9 @@ def after_reply(reply: dict, this_func, proto, fetched=0): if batch_size: get_more["batchSize"] = batch_size - next_reply = proto.send_simple_msg(get_more, codec_options) + next_reply = proto.send_msg( + Msg.create(get_more, codec_options=codec_options), codec_options + ) next_reply.addCallback(this_func, this_func, proto, fetched) return out, next_reply @@ -697,26 +688,23 @@ def _insert_one( document["_id"] = ObjectId() inserted_id = document["_id"] - msg = Msg( - flag_bits=Msg.create_flag_bits(self.write_concern.acknowledged), - body=bson.encode( - { - "insert": self.name, - "$db": self.database.name, - "writeConcern": self.write_concern.document, - } - ), - payload={ - "documents": [bson.encode(document, codec_options=self.codec_options)], + msg = Msg.create( + { + "insert": self.name, + "$db": self.database.name, + "writeConcern": self.write_concern.document, }, + { + "documents": [document], + }, + codec_options=self.codec_options, + acknowledged=self.write_concern.acknowledged, ) proto = yield self._database.connection.getprotocol() check_deadline(_deadline) - response: Optional[Msg] = yield proto.send_msg(msg) - if response: - reply = bson.decode(response.body, codec_options=self.codec_options) - _check_command_response(reply) + reply: Optional[dict] = yield proto.send_msg(msg, self.codec_options) + if reply: _check_write_command_response(reply) return InsertOneResult(inserted_id, self.write_concern.acknowledged) @@ -761,37 +749,30 @@ def gen(): @defer.inlineCallbacks def _update(self, filter, update, upsert, multi, _deadline): - msg = Msg( - flag_bits=Msg.create_flag_bits(self.write_concern.acknowledged), - body=bson.encode( - { - "update": self.name, - "$db": self.database.name, - "writeConcern": self.write_concern.document, - } - ), - payload={ + msg = Msg.create( + { + "update": self.name, + "$db": self.database.name, + "writeConcern": self.write_concern.document, + }, + { "updates": [ - bson.encode( - { - "q": filter, - "u": update, - "upsert": bool(upsert), - "multi": bool(multi), - }, - codec_options=self.codec_options, - ) + { + "q": filter, + "u": update, + "upsert": bool(upsert), + "multi": bool(multi), + } ], }, + codec_options=self.codec_options, + acknowledged=self.write_concern.acknowledged, ) proto = yield self._database.connection.getprotocol() check_deadline(_deadline) - response = yield proto.send_msg(msg) - reply = None - if response: - reply = bson.decode(response.body, codec_options=self.codec_options) - _check_command_response(reply) + reply = yield proto.send_msg(msg, self.codec_options) + if reply: _check_write_command_response(reply) if reply.get("n") and "upserted" in reply: # MongoDB >= 2.6.0 returns the upsert _id in an array @@ -916,29 +897,24 @@ def _delete( if let: body["let"] = let - msg = Msg( - flag_bits=Msg.create_flag_bits(self.write_concern.acknowledged), - body=bson.encode(body), - payload={ + msg = Msg.create( + body, + { "deletes": [ - bson.encode( - { - "q": filter, - "limit": 0 if multi else 1, - }, - codec_options=self.codec_options, - ) + { + "q": filter, + "limit": 0 if multi else 1, + }, ], }, + codec_options=self.codec_options, + acknowledged=self.write_concern.acknowledged, ) proto = yield self._database.connection.getprotocol() check_deadline(_deadline) - response = yield proto.send_msg(msg) - reply = None - if response: - reply = bson.decode(response.body, codec_options=self.codec_options) - _check_command_response(reply) + reply = yield proto.send_msg(msg, self.codec_options) + if reply: _check_write_command_response(reply) return DeleteResult(reply, self.write_concern.acknowledged) @@ -1256,13 +1232,9 @@ def _execute_bulk(self, bulk: _Bulk, _deadline: Optional[float]): } def accumulate_response(response: dict, run: _Run, idx_offset: int) -> dict: - _check_command_response(response) _merge_command(run, full_result, idx_offset, response) return response - def decode_response(response: Msg, codec_options: CodecOptions) -> dict: - return bson.decode(response.body, codec_options=codec_options) - got_error = False for run in bulk.gen_runs(): for doc_offset, msg in run.gen_messages( @@ -1273,9 +1245,8 @@ def decode_response(response: Msg, codec_options: CodecOptions) -> dict: self.codec_options, ): check_deadline(_deadline) - deferred = proto.send_msg(msg) + deferred = proto.send_msg(msg, self.codec_options) if effective_write_concern.acknowledged: - deferred.addCallback(decode_response, self.codec_options) if bulk.ordered: reply = yield deferred accumulate_response(reply, run, doc_offset) diff --git a/txmongo/database.py b/txmongo/database.py index 5ce9e301..be1147bb 100644 --- a/txmongo/database.py +++ b/txmongo/database.py @@ -5,6 +5,7 @@ from twisted.internet import defer from txmongo.collection import Collection +from txmongo.protocol import Msg from txmongo.pymongo_internals import _check_command_response from txmongo.utils import check_deadline, timeout @@ -77,12 +78,16 @@ def command( proto = yield self.connection.getprotocol() check_deadline(_deadline) - reply = yield proto.send_simple_msg(command, codec_options) - if check: - msg = "TxMongo: command {0} on namespace {1} failed with '%s'".format( - repr(command), self - ) - _check_command_response(reply, msg, allowable_errors) + errmsg = "TxMongo: command {0} on namespace {1} failed with '%s'".format( + repr(command), self + ) + reply = yield proto.send_msg( + Msg.create(command, codec_options=codec_options), + codec_options, + check=check, + errmsg=errmsg, + allowable_errors=allowable_errors, + ) return reply @timeout diff --git a/txmongo/protocol.py b/txmongo/protocol.py index 0bd473a7..251be06a 100644 --- a/txmongo/protocol.py +++ b/txmongo/protocol.py @@ -23,10 +23,10 @@ from dataclasses import dataclass, field from hashlib import sha1 from random import SystemRandom -from typing import Dict, List +from typing import Dict, List, Optional import bson -from bson import SON, Binary, CodecOptions +from bson import DEFAULT_CODEC_OPTIONS, SON, Binary, CodecOptions from pymongo.errors import ( AutoReconnect, ConfigurationError, @@ -39,6 +39,8 @@ from twisted.python import failure, log from txmongo.pymongo_errors import _NOT_MASTER_CODES +from txmongo.pymongo_internals import _check_command_response +from txmongo.types import Document try: from pymongo.synchronous import auth @@ -256,6 +258,35 @@ def opcode(cls): def create_flag_bits(cls, not_more_to_come: bool) -> int: return 0 if not_more_to_come else OP_MSG_MORE_TO_COME + @classmethod + def create( + cls, + body: Document, + payload: Dict[str, List[Document]] = None, + *, + codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS, + acknowledged: bool = True, + request_id: int = 0, + response_to: int = 0, + ) -> "Msg": + encoded_payload = {} + if payload: + encoded_payload = { + key: [bson.encode(doc, codec_options=codec_options) for doc in docs] + for key, docs in payload.items() + } + return Msg( + request_id=request_id, + response_to=response_to, + body=bson.encode(body, codec_options=codec_options), + flag_bits=0 if acknowledged else OP_MSG_MORE_TO_COME, + payload=encoded_payload, + ) + + @property + def acknowledged(self) -> bool: + return (self.flag_bits & OP_MSG_MORE_TO_COME) == 0 + def size_in_bytes(self) -> int: """return estimated overall message length including messageLength and requestID""" # checksum is not added since we don't support it for now @@ -519,50 +550,45 @@ def send_query(self, request): request_id = self._send(request) return self.__wait_for_reply_to(request_id) - def send_msg(self, msg: Msg) -> defer.Deferred[Msg]: - """Send Msg (OP_MSG) and return deferred. - - If OP_MSG has OP_MSG_MORE_TO_COME flag set, returns already fired deferred with None as a result. - """ + def _send_raw_msg(self, msg: Msg) -> defer.Deferred[Optional[Msg]]: + """Send OP_MSG and return result as Msg object (or None if not acknowledged)""" request_id = self._send(msg) - if msg.flag_bits & OP_MSG_MORE_TO_COME: - return defer.succeed(None) - return self.__wait_for_reply_to(request_id) + if msg.acknowledged: + return self.__wait_for_reply_to(request_id) + return defer.succeed(None) - def _check_master(self, answer: dict): - if answer.get("ok") == 0: - if answer.get("code", -1) in _NOT_MASTER_CODES: + @defer.inlineCallbacks + def send_msg( + self, + msg: Msg, + codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS, + *, + check: bool = True, + errmsg: str = None, + allowable_errors=None, + ) -> defer.Deferred[Optional[dict]]: + """Send OP_MSG and return parsed response as dict.""" + + response = yield self._send_raw_msg(msg) + if response is None: + return + + reply = bson.decode(response.body, codec_options) + for key, bin_docs in msg.payload.items(): + reply[key] = [bson.decode(doc, codec_options) for doc in bin_docs] + + if reply.get("ok") == 0: + if reply.get("code") in _NOT_MASTER_CODES: self.transport.loseConnection() - return NotPrimaryError( - "TxMongo: " + answer.get("errmsg", "Unknown error") + raise NotPrimaryError( + "TxMongo: " + reply.get("errmsg", "Unknown error") ) - def send_simple_msg( - self, body: dict, codec_options: CodecOptions - ) -> defer.Deferred[dict]: - """Send simple OP_MSG without extracted payload and return parsed response.""" - - def on_response(response: Msg): - reply = bson.decode(response.body, codec_options) - for key, bin_docs in msg.payload.items(): - reply[key] = [] - for doc in bin_docs: - answer = bson.decode(doc, codec_options) - # if check: - # if master_error := self._check_master(answer): - # reply[key].append(master_error) - # break - reply[key].append(answer) - return reply - - def on_response_v1(response: Msg): - reply = bson.decode(response.body, codec_options) - for key, bin_docs in msg.payload.items(): - reply[key] = [bson.decode(doc, codec_options) for doc in bin_docs] - return reply - - msg = Msg(body=bson.encode(body, codec_options=codec_options)) - return self.send_msg(msg).addCallback(on_response_v1) + if check: + _check_command_response( + reply, msg=errmsg, allowable_errors=allowable_errors + ) + return reply def handle(self, request: BaseMessage): if isinstance(request, Reply): @@ -600,13 +626,9 @@ def handle_reply(self, request: Reply): else: df.callback(request) - def handle_msg(self, msg: Msg): - if dfr := self.__deferreds.pop(msg.response_to, None): - answer = bson.decode(msg.body) - if master_error := self._check_master(answer): - dfr.errback(master_error) - else: - dfr.callback(msg) + def handle_msg(self, request: Msg): + if dfr := self.__deferreds.pop(request.response_to, None): + dfr.callback(request) def set_wire_versions(self, min_wire_version, max_wire_version): self.min_wire_version = min_wire_version diff --git a/txmongo/pymongo_internals.py b/txmongo/pymongo_internals.py index 0610e008..503df823 100644 --- a/txmongo/pymongo_internals.py +++ b/txmongo/pymongo_internals.py @@ -1,4 +1,6 @@ -from typing import Any, Mapping, MutableMapping, Optional +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, Optional from pymongo.errors import ( CursorNotFound, @@ -11,9 +13,12 @@ WTimeoutError, ) -from txmongo._bulk import _DELETE, _INSERT, _UPDATE, _Run +from txmongo._bulk_constants import _DELETE, _INSERT, _UPDATE from txmongo.pymongo_errors import _NOT_MASTER_CODES +if TYPE_CHECKING: + from txmongo._bulk import _Run + # Copied from pymongo/helpers.py:193 at commit 47b0d8ebfd6cefca80c1e4521b47aec7cf8f529d def _raise_last_write_error(write_errors): From 40c4f60b51fe27bbcea6efce72e9980924ce3c11 Mon Sep 17 00:00:00 2001 From: baranov Date: Thu, 24 Oct 2024 10:46:33 +0300 Subject: [PATCH 4/5] fix: test drop database when AutoReconnect is raised --- tests/utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/utils.py b/tests/utils.py index 94f9222a..a419ac29 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,3 +1,4 @@ +from pymongo.errors import AutoReconnect from twisted.internet import defer from twisted.trial import unittest @@ -15,5 +16,11 @@ def setUp(self): @defer.inlineCallbacks def tearDown(self): - yield self.coll.drop() + while True: + try: + yield self.coll.drop() + break + except AutoReconnect: + yield self.coll.drop() + yield self.conn.disconnect() From 63f3ccb16ca392cc32e480b3f5e2dc5926b8f9f2 Mon Sep 17 00:00:00 2001 From: baranov Date: Thu, 24 Oct 2024 11:04:55 +0300 Subject: [PATCH 5/5] refactor: test drop database when AutoReconnect is raised --- tests/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils.py b/tests/utils.py index a419ac29..2240af41 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -21,6 +21,6 @@ def tearDown(self): yield self.coll.drop() break except AutoReconnect: - yield self.coll.drop() + pass yield self.conn.disconnect()