diff --git a/app/routers/v2.py b/app/routers/v2.py index 31eb408c..fd91b292 100644 --- a/app/routers/v2.py +++ b/app/routers/v2.py @@ -2,7 +2,7 @@ import enum from fastapi import APIRouter, HTTPException, Request - +from ..services import ServiceRoot from ..data import DATA_SOURCES from ..models import LatestResponse, LocationResponse, LocationsResponse @@ -26,7 +26,11 @@ async def get_latest( """ Getting latest amount of total confirmed cases, deaths, and recoveries. """ - locations = await request.state.source.get_all() + + #service root for aggregate pattern + service = ServiceRoot(source) + + locations = await service.get_all() return { "latest": { "confirmed": sum(map(lambda location: location.confirmed, locations)), @@ -56,8 +60,10 @@ async def get_locations( params.pop("source", None) params.pop("timelines", None) + service = ServiceRoot(source) + # Retrieve all the locations. - locations = await request.state.source.get_all() + locations = await service.get_all() # Attempt to filter out locations with properties matching the provided query params. for key, value in params.items(): @@ -98,7 +104,12 @@ async def get_location_by_id( """ Getting specific location by id. """ - location = await request.state.source.get(id) + + #aggregate root for service layer + service = ServiceRoot(source) + + # Retrieve location. + location = await service.get(id) return {"location": location.serialize(timelines)} diff --git a/app/services/location/service_root.py b/app/services/location/service_root.py new file mode 100644 index 00000000..775049b9 --- /dev/null +++ b/app/services/location/service_root.py @@ -0,0 +1,26 @@ +from jhu import JHULocationService +from nyt import NYTLocationService +from csbs import CSBSLocationService + +class ServiceRoot: + + def __init__(self, source, service): + if source == '' or source == None or source == 'jhu': + self.source = 'jhu' + self.service = JHULocationService() + elif source == 'nyt': + self.source = source + self.service = NYTLocationService() + elif source == 'csbs': + self.source = source + self.service = CSBSLocationService() + else + self.source = None + + async def get_all(self): + locations = await self.service.get_all() + return locations + + async def get(self, id): + location = await self.service.get(id) + return location