diff --git a/app/data/__init__.py b/app/data/__init__.py index 60a75dac..455a0ce0 100644 --- a/app/data/__init__.py +++ b/app/data/__init__.py @@ -2,20 +2,45 @@ from ..services.location.csbs import CSBSLocationService from ..services.location.jhu import JhuLocationService from ..services.location.nyt import NYTLocationService +from ..utils.singleton import Singleton -# Mapping of services to data-sources. -DATA_SOURCES = { - "jhu": JhuLocationService(), - "csbs": CSBSLocationService(), - "nyt": NYTLocationService(), -} - -def data_source(source): +class DataSources(Singleton): """ - Retrieves the provided data-source service. - - :returns: The service. - :rtype: LocationService + Class to represent the root of the aggregate containing the location services. """ - return DATA_SOURCES.get(source.lower()) + + # Mapping of services to data-sources. + __DATA_SOURCES_MAP = { + "jhu": JhuLocationService(), + "csbs": CSBSLocationService(), + "nyt": NYTLocationService(), + } + + __instance = None + + def __init__(self): + pass + + def get_instance(): + if DataSources.__instance is None: + DataSources.__instance = DataSources() + return DataSources.__instance + + def get_data_source(self, source): + """ + Retrieves the provided data-source service. + + :returns: The service. + :rtype: LocationService + """ + return self.__DATA_SOURCES_MAP.get(source.lower()) + + def get_data_sources(self): + """ + Retrieves a dict of all data sources. + + :returns: The dictionary of data sources. + :rtype: dict + """ + return self.__DATA_SOURCES_MAP diff --git a/app/main.py b/app/main.py index b9aff949..d6b480e2 100644 --- a/app/main.py +++ b/app/main.py @@ -14,7 +14,7 @@ from sentry_sdk.integrations.asgi import SentryAsgiMiddleware from .config import get_settings -from .data import data_source +from .data import DataSources from .routers import V1, V2 from .utils.httputils import setup_client_session, teardown_client_session @@ -41,6 +41,8 @@ on_shutdown=[teardown_client_session], ) +DATA_SOURCES = DataSources.get_instance() + # ##################### # Middleware ####################### @@ -73,8 +75,8 @@ async def add_datasource(request: Request, call_next): """ Attach the data source to the request.state. """ - # Retrieve the datas ource from query param. - source = data_source(request.query_params.get("source", default="jhu")) + # Retrieve the data source from query param. + source = DATA_SOURCES.get_data_source(request.query_params.get("source", default="jhu")) # Abort with 404 if source cannot be found. if not source: diff --git a/app/routers/v2.py b/app/routers/v2.py index 31eb408c..96cdd67d 100644 --- a/app/routers/v2.py +++ b/app/routers/v2.py @@ -3,10 +3,11 @@ from fastapi import APIRouter, HTTPException, Request -from ..data import DATA_SOURCES +from ..data import DataSources from ..models import LatestResponse, LocationResponse, LocationsResponse V2 = APIRouter() +DATA_SOURCES = DataSources.get_instance() class Sources(str, enum.Enum): @@ -107,4 +108,4 @@ async def sources(): """ Retrieves a list of data-sources that are availble to use. """ - return {"sources": list(DATA_SOURCES.keys())} + return {"sources": list(DATA_SOURCES.get_data_sources().keys())} diff --git a/app/utils/singleton.py b/app/utils/singleton.py new file mode 100644 index 00000000..501e4fbd --- /dev/null +++ b/app/utils/singleton.py @@ -0,0 +1,11 @@ +class Singleton(object): + def __new__(cls, *args, **kwds): + it = cls.__dict__.get("__it__") + if it is not None: + return it + cls.__it__ = it = object.__new__(cls) + it.init(*args, **kwds) + return it + + def init(self, *args, **kwds): + pass