Skip to content

Commit

Permalink
cache models.get_original_user/object_key queries in memcache
Browse files Browse the repository at this point in the history
adds new generic common.memcache_memoize decorator for caching any function's output in memcache. for #1149
  • Loading branch information
snarfed committed Dec 12, 2024
1 parent 720d6d3 commit 5d6d68b
Show file tree
Hide file tree
Showing 9 changed files with 174 additions and 13 deletions.
1 change: 1 addition & 0 deletions app.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ env_variables:
BGS_HOST: bsky.network
MOD_SERVICE_HOST: mod.bsky.app
MOD_SERVICE_DID: did:plc:ar7c4by46qjdydhdevvrndac
MEMCACHE_HOST: '10.126.144.3'
# ...or test against labeler.dholms.xyz / did:plc:vzxheqfwpbi3lxbgdh22js66

handlers:
Expand Down
1 change: 1 addition & 0 deletions atproto_hub.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ env_variables:
MOD_SERVICE_HOST: mod.bsky.app
MOD_SERVICE_DID: did:plc:ar7c4by46qjdydhdevvrndac
# ...or test against labeler.dholms.xyz / did:plc:vzxheqfwpbi3lxbgdh22js66
MEMCACHE_HOST: '10.126.144.3'

ROLLBACK_WINDOW: 200000
SUBSCRIBE_REPOS_BATCH_DELAY: 10
Expand Down
36 changes: 24 additions & 12 deletions common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from datetime import timedelta
import functools
import logging
import os
from pathlib import Path
import re
import threading
Expand All @@ -26,6 +27,7 @@
from oauth_dropins.webutil.util import json_dumps
from negotiator import ContentNegotiator, AcceptParameters, ContentType
import pymemcache.client.base
from pymemcache.serde import PickleSerde
from pymemcache.test.utils import MockMemcacheClient

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -102,13 +104,17 @@

if appengine_info.DEBUG or appengine_info.LOCAL_SERVER:
logger.info('Using in memory mock memcache')
memcache = MockMemcacheClient()
memcache = MockMemcacheClient(allow_unicode_keys=True)
pickle_memcache = MockMemcacheClient(allow_unicode_keys=True, serde=PickleSerde())
global_cache = _InProcessGlobalCache()
else:
logger.info('Using production Memorystore memcache')
memcache = pymemcache.client.base.PooledClient(
'10.126.144.3', timeout=10, connect_timeout=10, # seconds
os.environ['MEMCACHE_HOST'], timeout=10, connect_timeout=10, # seconds
allow_unicode_keys=True)
pickle_memcache = pymemcache.client.base.PooledClient(
os.environ['MEMCACHE_HOST'], timeout=10, connect_timeout=10, # seconds
serde=PickleSerde(), allow_unicode_keys=True)
global_cache = MemcacheCache(memcache)

_negotiator = ContentNegotiator(acceptable=[
Expand Down Expand Up @@ -299,7 +305,7 @@ def webmention_endpoint_cache_key(url):
if parsed.path in ('', '/'):
key += ' /'

# logger.debug(f'wm cache key {key}')
logger.debug(f'wm cache key {key}')
return key


Expand Down Expand Up @@ -478,27 +484,33 @@ def memcache_key(key):
return key[:MEMCACHE_KEY_MAX_LEN].replace(' ', '%20').encode()


def memcache_memoize(expire=None):
"""Memoize function decorator that stores the cached value in memcache.
def memcache_memoize_key(fn, *args, **kwargs):
    """Builds the memcache key for one call to a memoized function.

    The raw key joins ``fn.__name__``, the literal ``2``, and the ``repr`` of
    the positional and keyword arguments with hyphens, then passes the result
    through :func:`memcache_key`.

    Args:
      fn (callable): the function being memoized; only its ``__name__`` is used
      args: positional arguments of the call
      kwargs: keyword arguments of the call

    Returns:
      whatever :func:`memcache_key` returns for the raw key
    """
    raw_key = f'{fn.__name__}-2-{args!r}-{kwargs!r}'
    return memcache_key(raw_key)


NOT YET WORKING! CURRENTLY UNUSED!
# Sentinel that memcache_memoize stores in place of None, so that a cached None
# result can be told apart from a cache miss (where the cache get returns None).
# NOTE(review): a memoized function that really returns an empty tuple will
# read back from the cache as None, since the cached value is compared with ==.
NONE = () # empty tuple

Only caches non-null/empty values.
def memcache_memoize(expire=None):
"""Memoize function decorator that stores the cached value in memcache.
Args:
expire (int): optional, expiration in seconds
expire (timedelta): optional, expiration
"""
if expire:
expire = int(expire.total_seconds())

def decorator(fn):
@functools.wraps(fn)
def wrapped(*args, **kwargs):
key = memcache_key(f'{fn.__name__}-{repr(args)}-{repr(kwargs)}')
if val := memcache.get(key):
key = memcache_memoize_key(fn, *args, **kwargs)
val = pickle_memcache.get(key)
if val is not None:
logger.debug(f'cache hit {key}')
return val
return None if val == NONE else val

logger.debug(f'cache miss {key}')
val = fn(*args, **kwargs)
memcache.set(key, val)
pickle_memcache.set(key, NONE if val is None else val, expire=expire)
return val

return wrapped
Expand Down
31 changes: 31 additions & 0 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
base64_to_long,
DOMAIN_RE,
long_to_base64,
memcache_memoize,
OLD_ACCOUNT_AGE,
report_error,
unwrap,
Expand Down Expand Up @@ -81,6 +82,8 @@
)
OBJECT_EXPIRE_AGE = timedelta(days=90)

GET_ORIGINALS_CACHE_EXPIRATION = timedelta(days=1)

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -233,6 +236,8 @@ def __init__(self, **kwargs):
if obj:
self.obj = obj

self.lock = Lock()

@classmethod
def new(cls, **kwargs):
"""Try to prevent instantiation. Use subclasses instead."""
Expand All @@ -241,6 +246,20 @@ def new(cls, **kwargs):
def _post_put_hook(self, future):
logger.debug(f'Wrote {self.key}')

def add(self, prop, val):
    """Adds a value to a multiply-valued property, holding ``self.lock``.

    When ``prop`` is ``copies``, also writes this entity's key into the
    :func:`get_original_user_key` memcache entry for ``val.uri``.

    Args:
      prop (str): name of the property to add to
      val: value to add; when ``prop`` is ``copies``, its ``uri`` is read
    """
    with self.lock:
        values = getattr(self, prop)
        util.add(values, val)

    if prop != 'copies':
        return

    # NOTE(review): this write passes no expire, unlike the decorator's own
    # writes, and it does not touch the lru_cache that wraps
    # get_original_user_key, so a value cached there stays stale - confirm
    # that both are intended.
    cache_key = common.memcache_memoize_key(get_original_user_key, val.uri)
    common.pickle_memcache.set(cache_key, self.key)

@classmethod
def get_by_id(cls, id, allow_opt_out=False, **kwargs):
"""Override to follow ``use_instead`` property and ``opt-out`` status.
Expand Down Expand Up @@ -1098,6 +1117,10 @@ def add(self, prop, val):
with self.lock:
util.add(getattr(self, prop), val)

if prop == 'copies':
common.pickle_memcache.set(common.memcache_memoize_key(
get_original_object_key, val.uri), self.key)

def remove(self, prop, val):
"""Removes a value from a multiply-valued property. Uses ``self.lock``.
Expand Down Expand Up @@ -1659,9 +1682,13 @@ def get_paging_param(param):


@lru_cache(maxsize=100000)
@memcache_memoize(expire=GET_ORIGINALS_CACHE_EXPIRATION)
def get_original_object_key(copy_id):
"""Finds the :class:`Object` with a given copy id, if any.
Note that :meth:`Object.add` also updates this function's
:func:`memcache_memoize` cache.
Args:
copy_id (str)
Expand All @@ -1674,9 +1701,13 @@ def get_original_object_key(copy_id):


@lru_cache(maxsize=100000)
@memcache_memoize(expire=GET_ORIGINALS_CACHE_EXPIRATION)
def get_original_user_key(copy_id):
"""Finds the user with a given copy id, if any.
Note that :meth:`User.add` also updates this function's
:func:`memcache_memoize` cache.
Args:
copy_id (str)
not_proto (Protocol): optional, don't query this protocol
Expand Down
1 change: 1 addition & 0 deletions router.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ env_variables:
# https://bsky.app/profile/gargaj.umlaut.hu/post/3kxsvpqiuln26
CHAT_HOST: api.bsky.chat
CHAT_DID: did:web:api.bsky.chat
MEMCACHE_HOST: '10.126.144.3'

automatic_scaling:
min_num_instances: 2
Expand Down
84 changes: 83 additions & 1 deletion tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from unittest.mock import Mock, patch

import flask
from google.cloud.ndb import Key
from granary import as2
from oauth_dropins.webutil.appengine_config import error_reporting_client

Expand All @@ -13,7 +14,7 @@
import common
from arroba.datastore_storage import AtpBlock
from flask_app import app
from models import Follower, Object
from models import Follower, Object, Target
from ui import UIProtocol
from web import Web

Expand Down Expand Up @@ -174,6 +175,87 @@ def test_memcache_key(self):
):
self.assertEqual(expected, common.memcache_key(input))

def test_memcache_memoize(self):
    """Repeated calls with the same args return the cached value without re-running."""
    invocations = []

    @common.memcache_memoize()
    def foo(x, y, z=None):
        invocations.append((x, y, z))
        return len(invocations)

    first = (1, 'a', 1)
    second = (2, 'b', 2)

    # the first pass runs foo; the second pass must not add to invocations
    for _ in range(2):
        self.assertEqual(1, foo(1, 'a', z=1))
        self.assertEqual([first], invocations)

    # different args run foo once more
    self.assertEqual(2, foo(2, 'b', z=2))
    self.assertEqual([first, second], invocations)

    # both results are now cached independently
    self.assertEqual(1, foo(1, 'a', z=1))
    self.assertEqual(2, foo(2, 'b', z=2))
    self.assertEqual([first, second], invocations)

# def test_memcache_memoize_Object(self):
# calls = []

# obj = Object(users=[Key(Object, 'abc')],
# copies=[Target(uri='abc', protocol='web')],
# as2={'foo': 'x ☕ y', 'bar': True, 'baz': 5})

# @common.memcache_memoize()
# def foo(x):
# calls.append(x)
# obj.key = Key(Object, x)
# return obj

# expected_a = Object(id='a', **obj.to_dict(include=['users', 'copies', 'as2']))
# self.assert_entities_equal(expected_a, foo('a'))
# self.assertEqual(['a'], calls)
# self.assert_entities_equal(expected_a, foo('a'))
# self.assertEqual(['a'], calls)

# expected_b = Object(id='b', **obj.to_dict(include=['users', 'copies', 'as2']))
# self.assert_entities_equal(expected_b, foo('b'))
# self.assertEqual(['a', 'b'], calls)
# self.assert_entities_equal(expected_a, foo('a'))
# self.assertEqual(['a', 'b'], calls)
# self.assert_entities_equal(expected_b, foo('b'))
# self.assertEqual(['a', 'b'], calls)

def test_memcache_memoize_Key(self):
    """ndb ``Key`` return values come back equal when read from the cache."""
    invocations = []

    @common.memcache_memoize()
    def foo(x):
        invocations.append(x)
        return Key(Object, x)

    def check(expected_key, arg, expected_invocations):
        self.assertEqual(expected_key, foo(arg))
        self.assertEqual(expected_invocations, invocations)

    key_a = Key(Object, 'a')
    check(key_a, 'a', ['a'])
    check(key_a, 'a', ['a'])  # second call does not run foo again

    key_b = Key(Object, 'b')
    check(key_b, 'b', ['a', 'b'])
    check(key_a, 'a', ['a', 'b'])
    check(key_b, 'b', ['a', 'b'])

def test_memcache_memoize_None(self):
    """A ``None`` return value is cached too: the second call does not run foo."""
    invocations = []

    @common.memcache_memoize()
    def foo(x):
        invocations.append(x)
        return None

    for _ in range(2):
        self.assertIsNone(foo('a'))
        self.assertEqual(['a'], invocations)

def test_as2_request_type(self):
for accept, expected in (
(as2.CONTENT_TYPE_LD_PROFILE, as2.CONTENT_TYPE_LD_PROFILE),
Expand Down
27 changes: 27 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,17 @@ def test_is_enabled_protocol_bot_users(self):
self.assertFalse(Web(id='ap.brid.gy').is_enabled(ActivityPub))
self.assertFalse(Web(id='bsky.brid.gy').is_enabled(ATProto))

def test_add_to_copies_updates_memcache(self):
    """Adding to a user's ``copies`` fills the ``get_original_user_key`` cache entry."""
    copy_uri = 'other:x'
    key = common.memcache_memoize_key(models.get_original_user_key, copy_uri)
    cache = common.pickle_memcache

    # nothing cached for this copy id yet
    self.assertIsNone(cache.get(key))

    user = Fake(id='fake:x')
    user.add('copies', Target(protocol='other', uri=copy_uri))

    self.assertEqual(user.key, cache.get(key))


class ObjectTest(TestCase):
def setUp(self):
Expand Down Expand Up @@ -1010,6 +1021,7 @@ def test_resolve_ids_copies_follow(self):

models.get_original_user_key.cache_clear()
models.get_original_object_key.cache_clear()
common.pickle_memcache.clear()

# matching copy users
self.make_user('other:alice', cls=OtherFake,
Expand Down Expand Up @@ -1048,6 +1060,7 @@ def test_resolve_ids_copies_reply(self):

models.get_original_user_key.cache_clear()
models.get_original_object_key.cache_clear()
common.pickle_memcache.clear()

# matching copies
self.make_user('other:alice', cls=OtherFake,
Expand Down Expand Up @@ -1089,6 +1102,7 @@ def test_resolve_ids_multiple_in_reply_to(self):

models.get_original_user_key.cache_clear()
models.get_original_object_key.cache_clear()
common.pickle_memcache.clear()

# matching copies
self.store_object(id='other:a',
Expand Down Expand Up @@ -1202,13 +1216,15 @@ def test_normalize_ids_reply(self):
def test_get_original_user_key(self):
self.assertIsNone(models.get_original_user_key('other:user'))
models.get_original_user_key.cache_clear()
common.pickle_memcache.clear()
user = self.make_user('fake:user', cls=Fake,
copies=[Target(uri='other:user', protocol='other')])
self.assertEqual(user.key, models.get_original_user_key('other:user'))

def test_get_original_object_key(self):
self.assertIsNone(models.get_original_object_key('other:post'))
models.get_original_object_key.cache_clear()
common.pickle_memcache.clear()
obj = self.store_object(id='fake:post',
copies=[Target(uri='other:post', protocol='other')])
self.assertEqual(obj.key, models.get_original_object_key('other:post'))
Expand All @@ -1226,6 +1242,17 @@ def test_get_copy(self):
obj.copies.append(Target(uri='fake:foo', protocol='fake'))
self.assertEqual('fake:foo', obj.get_copy(Fake))

def test_add_to_copies_updates_memcache(self):
    """Adding to an object's ``copies`` fills the ``get_original_object_key`` cache entry."""
    copy_uri = 'other:x'
    key = common.memcache_memoize_key(models.get_original_object_key, copy_uri)
    cache = common.pickle_memcache

    # nothing cached for this copy id yet
    self.assertIsNone(cache.get(key))

    obj = Object(id='x')
    obj.add('copies', Target(protocol='other', uri=copy_uri))

    self.assertEqual(obj.key, cache.get(key))


class FollowerTest(TestCase):

Expand Down
Loading

0 comments on commit 5d6d68b

Please sign in to comment.