diff --git a/app/data/__init__.py b/app/data/__init__.py index 60a75dac..e94b7d62 100644 --- a/app/data/__init__.py +++ b/app/data/__init__.py @@ -4,18 +4,8 @@ from ..services.location.nyt import NYTLocationService # Mapping of services to data-sources. -DATA_SOURCES = { - "jhu": JhuLocationService(), - "csbs": CSBSLocationService(), - "nyt": NYTLocationService(), -} - - -def data_source(source): - """ - Retrieves the provided data-source service. - - :returns: The service. - :rtype: LocationService - """ - return DATA_SOURCES.get(source.lower()) +DATA_SOURCES = [ + "jhu", + "csbs", + "nyt", +] diff --git a/app/main.py b/app/main.py index b9aff949..a1847183 100644 --- a/app/main.py +++ b/app/main.py @@ -14,7 +14,7 @@ from sentry_sdk.integrations.asgi import SentryAsgiMiddleware from .config import get_settings -from .data import data_source +from .services.location import LocationServiceFactory from .routers import V1, V2 from .utils.httputils import setup_client_session, teardown_client_session @@ -74,7 +74,9 @@ async def add_datasource(request: Request, call_next): Attach the data source to the request.state. """ # Retrieve the datas ource from query param. - source = data_source(request.query_params.get("source", default="jhu")) + source_name = request.query_params.get("source", default="jhu") + location_service_factory = LocationServiceFactory() + source = location_service_factory.create(source_name) # Abort with 404 if source cannot be found. if not source: diff --git a/app/routers/v2.py b/app/routers/v2.py index 31eb408c..b912fb88 100644 --- a/app/routers/v2.py +++ b/app/routers/v2.py @@ -107,4 +107,4 @@ async def sources(): """ Retrieves a list of data-sources that are availble to use. """ - return {"sources": list(DATA_SOURCES.keys())} + return {"sources": DATA_SOURCES} diff --git a/app/services/location/locationservicefactory.py b/app/services/location/locationservicefactory.py new file mode 100644 index 00000000..e7fb8620 --- /dev/null +++ b/app/services/location/locationservicefactory.py @@ -0,0 +1,21 @@ + + +from ..location.csbs import CSBSLocationService +from ..location.jhu import JhuLocationService +from ..location.nyt import NYTLocationService + + +NEW_YORK_TIMES = 'nyt' +JOHNS_HOPKINS_UNIVERSITY = 'jhu' +CSBS = 'csbs' + + +class LocationServiceFactory: + def create(location_service_name): + lowercase_service_name = location_service_name.lower() + if lowercase_service_name == CSBS: + return CSBSLocationService() + elif lowercase_service_name == JOHNS_HOPKINS_UNIVERSITY: + return JhuLocationService() + elif lowercase_service_name == NEW_YORK_TIMES: + return NYTLocationService()