diff --git a/app/routers/v2.py b/app/routers/v2.py index 31eb408c..98337acb 100644 --- a/app/routers/v2.py +++ b/app/routers/v2.py @@ -8,20 +8,38 @@ V2 = APIRouter() +class State(): -class Sources(str, enum.Enum): - """ - A source available for retrieving data. - """ + states = ["jhu", "csbs", "nyt"] + + def set_states(self, state): + if state.source in self.states: + self.__class__ = state + + def __str__(self): + return self.source - JHU = "jhu" - CSBS = "csbs" - NYT = "nyt" +class jhu(State): source = "jhu" +class csbs(State): source = "csbs" +class nyt(State): source = "nyt" + +class Sources(): + """ + A source available for retrieving data. + """ + + def __init__(self): + self.current = jhu() + + def set_state(self, state): + self.current.set_states(state) + +sources = Sources() @V2.get("/latest", response_model=LatestResponse) async def get_latest( - request: Request, source: Sources = Sources.JHU + request: Request, source: Sources = sources.current ): # pylint: disable=unused-argument """ Getting latest amount of total confirmed cases, deaths, and recoveries. @@ -40,7 +58,7 @@ async def get_latest( @V2.get("/locations", response_model=LocationsResponse, response_model_exclude_unset=True) async def get_locations( request: Request, - source: Sources = "jhu", + source: sources.current, country_code: str = None, province: str = None, county: str = None, @@ -93,7 +111,7 @@ async def get_locations( # pylint: disable=invalid-name @V2.get("/locations/{id}", response_model=LocationResponse) async def get_location_by_id( - request: Request, id: int, source: Sources = Sources.JHU, timelines: bool = True + request: Request, id: int, source: Sources = sources.current, timelines: bool = True ): """ Getting specific location by id.