diff --git a/app/caches.py b/app/caches.py index df95f508..110b2ccb 100644 --- a/app/caches.py +++ b/app/caches.py @@ -1,52 +1,77 @@ """app.caches.py""" import functools import logging -from typing import Union - import aiocache -from .config import get_settings +from abc import ABC, abstractmethod LOGGER = logging.getLogger(name="app.caches") -SETTINGS = get_settings() - -if SETTINGS.rediscloud_url: - REDIS_URL = SETTINGS.rediscloud_url - LOGGER.info("Using Rediscloud") -else: - REDIS_URL = SETTINGS.local_redis_url - LOGGER.info("Using Local Redis") - - -@functools.lru_cache() -def get_cache(namespace) -> Union[aiocache.RedisCache, aiocache.SimpleMemoryCache]: - """Retunr """ - if REDIS_URL: - LOGGER.info("using RedisCache") - return aiocache.RedisCache( - endpoint=REDIS_URL.host, - port=REDIS_URL.port, - password=REDIS_URL.password, - namespace=namespace, +class Caches(ABC): + + @property + def cache(self): + raise NotImplementedError + + @abstractmethod + def get_cache(self): + raise NotImplementedError + + @abstractmethod + async def check_cache(self, data_id): + raise NotImplementedError + + @abstractmethod + async def load_cache(self, data_id, data): + raise NotImplementedError + +class RedisCache(Caches): + cache = None + + def __init__(self, redis_url): + self.cache = aiocache.RedisCache( + endpoint=redis_url.host, + port=redis_url.port, + password=redis_url.password, create_connection_timeout=5, ) - LOGGER.info("using SimpleMemoryCache") - return aiocache.SimpleMemoryCache(namespace=namespace) + + @functools.lru_cache() + def get_cache(self): + return self.cache + + async def check_cache(self, data_id): + cache = self.get_cache() + result = await cache.get(data_id, None) + LOGGER.info(f"{data_id} cache pulled") + await cache.close() + return result + + async def load_cache(self, data_id, data): + cache = self.get_cache() + await cache.set(data_id, data, ttl=3600) + LOGGER.info(f"{data_id} cache loaded") + await cache.close() + +class SimpleMemoryCache(Caches): 
+ cache = None + def __init__(self): + self.cache = aiocache.SimpleMemoryCache() -async def check_cache(data_id: str, namespace: str = None): - """Check the data of a cache given an id.""" - cache = get_cache(namespace) - result = await cache.get(data_id, None) - LOGGER.info(f"{data_id} cache pulled") - await cache.close() - return result + @functools.lru_cache() + def get_cache(self): + return self.cache + async def check_cache(self, data_id): + cache = self.get_cache() + result = await cache.get(data_id, None) + LOGGER.info(f"{data_id} cache pulled") + await cache.close() + return result -async def load_cache(data_id: str, data, namespace: str = None, cache_life: int = 3600): - """Load data into the cache.""" - cache = get_cache(namespace) - await cache.set(data_id, data, ttl=cache_life) - LOGGER.info(f"{data_id} cache loaded") - await cache.close() + async def load_cache(self, data_id, data): + cache = self.get_cache() + await cache.set(data_id, data, ttl=3600) + LOGGER.info(f"{data_id} cache loaded") + await cache.close() diff --git a/app/data/__init__.py b/app/data/__init__.py index 60a75dac..dfab7b62 100644 --- a/app/data/__init__.py +++ b/app/data/__init__.py @@ -2,12 +2,26 @@ from ..services.location.csbs import CSBSLocationService from ..services.location.jhu import JhuLocationService from ..services.location.nyt import NYTLocationService +from ..config import get_settings +from ..caches import RedisCache, SimpleMemoryCache + +SETTINGS = get_settings() + +if SETTINGS.rediscloud_url: + REDIS_URL = SETTINGS.rediscloud_url +else: + REDIS_URL = SETTINGS.local_redis_url + +if REDIS_URL: + CACHE = RedisCache(REDIS_URL) +else: + CACHE = SimpleMemoryCache() # Mapping of services to data-sources. 
DATA_SOURCES = { - "jhu": JhuLocationService(), - "csbs": CSBSLocationService(), - "nyt": NYTLocationService(), + "jhu": JhuLocationService(CACHE), + "csbs": CSBSLocationService(CACHE), + "nyt": NYTLocationService(CACHE), } diff --git a/app/services/location/csbs.py b/app/services/location/csbs.py index 444ebad6..87c429d2 100644 --- a/app/services/location/csbs.py +++ b/app/services/location/csbs.py @@ -6,7 +6,7 @@ from asyncache import cached from cachetools import TTLCache -from ...caches import check_cache, load_cache +from ...caches import Caches from ...coordinates import Coordinates from ...location.csbs import CSBSLocation from ...utils import httputils @@ -14,12 +14,17 @@ LOGGER = logging.getLogger("services.location.csbs") +CACHE = None class CSBSLocationService(LocationService): """ Service for retrieving locations from csbs """ + def __init__(self, cache: Caches): + global CACHE + CACHE = cache + async def get_all(self): # Get the locations. locations = await get_locations() @@ -46,7 +51,7 @@ async def get_locations(): data_id = "csbs.locations" LOGGER.info(f"{data_id} Requesting data...") # check shared cache - cache_results = await check_cache(data_id) + cache_results = await CACHE.check_cache(data_id) if cache_results: LOGGER.info(f"{data_id} using shared cache results") locations = cache_results @@ -94,7 +99,7 @@ async def get_locations(): # save the results to distributed cache # TODO: fix json serialization try: - await load_cache(data_id, locations) + await CACHE.load_cache(data_id, locations) except TypeError as type_err: LOGGER.error(type_err) diff --git a/app/services/location/jhu.py b/app/services/location/jhu.py index ebed3960..16615401 100644 --- a/app/services/location/jhu.py +++ b/app/services/location/jhu.py @@ -8,7 +8,7 @@ from asyncache import cached from cachetools import TTLCache -from ...caches import check_cache, load_cache +from ...caches import Caches from ...coordinates import Coordinates from ...location import TimelinedLocation 
from ...models import Timeline @@ -19,13 +19,17 @@ LOGGER = logging.getLogger("services.location.jhu") PID = os.getpid() - +CACHE = None class JhuLocationService(LocationService): """ Service for retrieving locations from Johns Hopkins CSSE (https://github.com/CSSEGISandData/COVID-19). """ + def __init__(self, cache: Caches): + global CACHE + CACHE = cache + async def get_all(self): # Get the locations. locations = await get_locations() @@ -57,7 +61,7 @@ async def get_category(category): data_id = f"jhu.{category}" # check shared cache - cache_results = await check_cache(data_id) + cache_results = await CACHE.check_cache(data_id) if cache_results: LOGGER.info(f"{data_id} using shared cache results") results = cache_results @@ -121,7 +125,7 @@ async def get_category(category): "source": "https://github.com/ExpDev07/coronavirus-tracker-api", } # save the results to distributed cache - await load_cache(data_id, results) + await CACHE.load_cache(data_id, results) LOGGER.info(f"{data_id} results:\n{pf(results, depth=1)}") return results diff --git a/app/services/location/nyt.py b/app/services/location/nyt.py index 1f25ec34..529a9226 100644 --- a/app/services/location/nyt.py +++ b/app/services/location/nyt.py @@ -5,8 +5,8 @@ from asyncache import cached from cachetools import TTLCache -from ...caches import check_cache, load_cache +from ...caches import Caches from ...coordinates import Coordinates from ...location.nyt import NYTLocation from ...models import Timeline @@ -15,12 +15,17 @@ LOGGER = logging.getLogger("services.location.nyt") +CACHE = None class NYTLocationService(LocationService): """ Service for retrieving locations from New York Times (https://github.com/nytimes/covid-19-data). """ + def __init__(self, cache: Caches): + global CACHE + CACHE = cache + async def get_all(self): # Get the locations. locations = await get_locations() @@ -79,7 +84,7 @@ async def get_locations(): # Request the data. 
LOGGER.info(f"{data_id} Requesting data...") # check shared cache - cache_results = await check_cache(data_id) + cache_results = await CACHE.check_cache(data_id) if cache_results: LOGGER.info(f"{data_id} using shared cache results") locations = cache_results @@ -138,7 +143,7 @@ async def get_locations(): # save the results to distributed cache # TODO: fix json serialization try: - await load_cache(data_id, locations) + await CACHE.load_cache(data_id, locations) except TypeError as type_err: LOGGER.error(type_err)