diff --git a/diskcache/core.py b/diskcache/core.py index c7c8486..ebf9a75 100644 --- a/diskcache/core.py +++ b/diskcache/core.py @@ -707,44 +707,48 @@ def transact(self, retry=False): @cl.contextmanager def _transact(self, retry=False, filename=None): - sql = self._sql - filenames = [] - _disk_remove = self._disk.remove - tid = threading.get_ident() - txn_id = self._txn_id + _acquireLock() + try: + sql = self._sql + filenames = [] + _disk_remove = self._disk.remove + tid = threading.get_ident() + txn_id = self._txn_id - if tid == txn_id: - begin = False - else: - while True: - try: - sql('BEGIN IMMEDIATE') - begin = True - self._txn_id = tid - break - except sqlite3.OperationalError: - if retry: - continue - if filename is not None: - _disk_remove(filename) - raise Timeout from None + if tid == txn_id: + begin = False + else: + while True: + try: + sql('BEGIN IMMEDIATE') + begin = True + self._txn_id = tid + break + except sqlite3.OperationalError: + if retry: + continue + if filename is not None: + _disk_remove(filename) + raise Timeout from None - try: - yield sql, filenames.append - except BaseException: - if begin: - assert self._txn_id == tid - self._txn_id = None - sql('ROLLBACK') - raise - else: - if begin: - assert self._txn_id == tid - self._txn_id = None - sql('COMMIT') - for name in filenames: - if name is not None: - _disk_remove(name) + try: + yield sql, filenames.append + except BaseException: + if begin: + assert self._txn_id == tid + self._txn_id = None + sql('ROLLBACK') + raise + else: + if begin: + assert self._txn_id == tid + self._txn_id = None + sql('COMMIT') + for name in filenames: + if name is not None: + _disk_remove(name) + finally: + _releaseLock() def set(self, key, value, expire=None, read=False, tag=None, retry=False): """Set `key` and `value` item in cache. @@ -2450,3 +2454,32 @@ def reset(self, key, value=ENOVAL, update=True): setattr(self, key, value) return value + +if hasattr(os, 'register_at_fork'): + _lock = threading.RLock() + + def _acquireLock(): + global _lock + try: + _lock.acquire() + except BaseException: + _lock.release() + raise + + def _releaseLock(): + global _lock + _lock.release() + + def _after_at_fork_child_reinit_locks(): + global _lock + _lock = threading.RLock() + + os.register_at_fork(before=_acquireLock, + after_in_child=_after_at_fork_child_reinit_locks, + after_in_parent=_releaseLock) +else: + def _acquireLock(): + pass + + def _releaseLock(): + pass diff --git a/tests/test_fork_multithreading.py b/tests/test_fork_multithreading.py new file mode 100644 index 0000000..00a47ca --- /dev/null +++ b/tests/test_fork_multithreading.py @@ -0,0 +1,71 @@ +""" +Test diskcache.core.Cache behaviour when process is forking. +Make sure it does not deadlock on the sqlite3 transaction lock if +forked while the lock is in use. +""" + +import errno +import hashlib +import io +import os +import os.path as op +import sys +import pathlib +import pickle +import shutil +import sqlite3 +import subprocess as sp +import tempfile +import threading +import time +import warnings +from threading import Thread +from unittest import mock + +if sys.platform != "win32": + import signal + +import pytest + +import diskcache as dc + +REPEATS = 1000 + +@pytest.fixture +def cache(): + with dc.Cache() as cache: + yield cache + shutil.rmtree(cache.directory, ignore_errors=True) + +def _test_thread_imp(cache): + for i in range(REPEATS * 10): + cache.set(i, i) + +def _test_wait_pid(pid): + _, status = os.waitpid(pid, 0) + assert status == 0, "Child died unexpectedly" + +@pytest.mark.skipif(sys.platform == "win32", reason="no fork on Windows") +def test_fork_multithreading(cache): + thread = Thread(target=_test_thread_imp, args=(cache,)) + thread.start() + try: + for i in range(REPEATS): + pid = os.fork() + if pid == 0: + cache.set(i, 0) + os._exit(0) + else: + thread = Thread(target=_test_wait_pid, args=(pid,)) + thread.start() + thread.join(timeout=10) + if thread.is_alive(): + os.kill(pid, signal.SIGKILL) + thread.join() + assert False, "Deadlock detected." + except OSError as e: + if e.errno != errno.EINTR: + raise + + thread.join() +