diff --git a/app/data/__init__.py b/app/data/__init__.py index 60a75dac..90d236be 100644 --- a/app/data/__init__.py +++ b/app/data/__init__.py @@ -3,19 +3,33 @@ from ..services.location.jhu import JhuLocationService from ..services.location.nyt import NYTLocationService -# Mapping of services to data-sources. -DATA_SOURCES = { - "jhu": JhuLocationService(), - "csbs": CSBSLocationService(), - "nyt": NYTLocationService(), -} +class Source: + def __init__(self, source) -> None: + self._sources = {"jhu","csbs","nyt"} + if source == "csbs": + self._service = CSBSLocationService() + elif source == "nyt": + self._service = NYTLocationService() + else: + self._service = JhuLocationService() -def data_source(source): - """ - Retrieves the provided data-source service. + if source not in self._sources: + self._service = None + + def get_sources(self): + """ + Return the list of available sources. + """ + return self._sources - :returns: The service. - :rtype: LocationService - """ - return DATA_SOURCES.get(source.lower()) + def get_service(self): + """ + Retrieves the provided data-source service. + + :returns: The service. + :rtype: LocationService + """ + return self._service + + \ No newline at end of file diff --git a/app/main.py b/app/main.py index b9aff949..3364c40e 100644 --- a/app/main.py +++ b/app/main.py @@ -14,7 +14,7 @@ from sentry_sdk.integrations.asgi import SentryAsgiMiddleware from .config import get_settings -from .data import data_source +from .data import Source from .routers import V1, V2 from .utils.httputils import setup_client_session, teardown_client_session @@ -74,10 +74,10 @@ async def add_datasource(request: Request, call_next): Attach the data source to the request.state. """ # Retrieve the datas ource from query param. - source = data_source(request.query_params.get("source", default="jhu")) + source = Source(request.query_params.get("source", default="jhu")) # Abort with 404 if source cannot be found. - if not source: + if not source.get_service(): return Response("The provided data-source was not found.", status_code=404) # Attach source to request. diff --git a/app/routers/v2.py b/app/routers/v2.py index 31eb408c..bc72857b 100644 --- a/app/routers/v2.py +++ b/app/routers/v2.py @@ -3,7 +3,6 @@ from fastapi import APIRouter, HTTPException, Request -from ..data import DATA_SOURCES from ..models import LatestResponse, LocationResponse, LocationsResponse V2 = APIRouter() @@ -26,7 +25,7 @@ async def get_latest( """ Getting latest amount of total confirmed cases, deaths, and recoveries. """ - locations = await request.state.source.get_all() + locations = await request.state.source.get_service().get_all() return { "latest": { "confirmed": sum(map(lambda location: location.confirmed, locations)), @@ -57,7 +56,7 @@ async def get_locations( params.pop("timelines", None) # Retrieve all the locations. - locations = await request.state.source.get_all() + locations = await request.state.source.get_service().get_all() # Attempt to filter out locations with properties matching the provided query params. for key, value in params.items(): @@ -98,7 +97,7 @@ async def get_location_by_id( """ Getting specific location by id. """ - location = await request.state.source.get(id) + location = await request.state.source.get_service().get(id) return {"location": location.serialize(timelines)} @@ -107,4 +106,4 @@ async def sources(): """ Retrieves a list of data-sources that are availble to use. """ - return {"sources": list(DATA_SOURCES.keys())} + return {"sources": [source.value for source in Sources]}