Skip to content

Commit 00dd327

Browse files
committed
use Source enum instead of magic strings
1 parent 31b93ea commit 00dd327

File tree

3 files changed

+10
-8
lines changed

3 files changed

+10
-8
lines changed

app/data/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ class Sources(str, enum.Enum):
1818

1919
# Mapping of services to data-sources.
2020
DATA_SOURCES = {
21-
"jhu": JhuLocationService(),
22-
"csbs": CSBSLocationService(),
23-
"nyt": NYTLocationService(),
21+
Sources.jhu: JhuLocationService(),
22+
Sources.csbs: CSBSLocationService(),
23+
Sources.nyt: NYTLocationService(),
2424
}
2525

2626

app/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from sentry_sdk.integrations.asgi import SentryAsgiMiddleware
1515

1616
from .config import get_settings
17-
from .data import data_source
17+
from .data import Sources, data_source
1818
from .routers import V1, V2
1919
from .utils.httputils import setup_client_session, teardown_client_session
2020

@@ -74,7 +74,7 @@ async def add_datasource(request: Request, call_next):
7474
Attach the data source to the request.state.
7575
"""
7676
# Retrieve the datas ource from query param.
77-
source = data_source(request.query_params.get("source", default="jhu"))
77+
source = data_source(request.query_params.get("source", default=Sources.jhu))
7878

7979
# Abort with 404 if source cannot be found.
8080
if not source:

app/routers/v2.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010

1111

1212
@V2.get("/latest", response_model=LatestResponse)
13-
async def get_latest(request: Request, source: Sources = "jhu"): # pylint: disable=unused-argument
13+
async def get_latest(
14+
request: Request, source: Sources = Sources.jhu
15+
): # pylint: disable=unused-argument
1416
"""
1517
Getting latest amount of total confirmed cases, deaths, and recoveries.
1618
"""
@@ -28,7 +30,7 @@ async def get_latest(request: Request, source: Sources = "jhu"): # pylint: disa
2830
@V2.get("/locations", response_model=LocationsResponse, response_model_exclude_unset=True)
2931
async def get_locations(
3032
request: Request,
31-
source: Sources = "jhu",
33+
source: Sources = Sources.jhu,
3234
country_code: str = None,
3335
province: str = None,
3436
county: str = None,
@@ -81,7 +83,7 @@ async def get_locations(
8183
# pylint: disable=invalid-name
8284
@V2.get("/locations/{id}", response_model=LocationResponse)
8385
async def get_location_by_id(
84-
request: Request, id: int, source: Sources = "jhu", timelines: bool = True
86+
request: Request, id: int, source: Sources = Sources.jhu, timelines: bool = True
8587
):
8688
"""
8789
Getting specific location by id.

0 commit comments

Comments
 (0)