Skip to content

Commit aa8b1f3

Browse files
committed
make all services into one service which can be injected with gateway class
1 parent 1c7e4ae commit aa8b1f3

File tree

6 files changed

+433
-422
lines changed

6 files changed

+433
-422
lines changed

app/data/__init__.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,34 @@
11
"""app.data"""
2-
from ..services.location.csbs import CSBSLocationService
3-
from ..services.location.jhu import JhuLocationService
4-
from ..services.location.nyt import NYTLocationService
2+
from ..services.location.csbs import CSBSGateway
3+
from ..services.location.jhu import JHUGateway
4+
from ..services.location.nyt import NYTGateway
5+
6+
from ..services.location import LocationGateway, LocationService
7+
8+
9+
10+
class ServiceFactory:
11+
12+
def create_service(self, source_name: str):
13+
source_name = source_name.lower()
14+
15+
gateway: LocationGateway
16+
17+
if source_name == 'jhu':
18+
gateway = JHUGateway("https://raw.githubusercontent.com/CSSEGISandData/2019-nCoV/master/csse_covid_19_data/csse_covid_19_time_series/")
19+
elif source_name == 'csbs':
20+
gateway = CSBSGateway("https://facts.csbs.org/covid-19/covid19_county.csv")
21+
elif source_name == 'nyt':
22+
gateway = NYTGateway("https://raw.githubusercontent.com/nytimes/covid-19-data/master/us-counties.csv")
23+
24+
service: LocationService = LocationService(gateway)
25+
526

627
# Mapping of services to data-sources.
728
DATA_SOURCES = {
8-
"jhu": JhuLocationService(),
9-
"csbs": CSBSLocationService(),
10-
"nyt": NYTLocationService(),
29+
"jhu": ServiceFactory().create_service("jhu"),
30+
"csbs": ServiceFactory().create_service("csbs"),
31+
"nyt": ServiceFactory().create_service("nyt"),
1132
}
1233

1334

@@ -19,3 +40,7 @@ def data_source(source):
1940
:rtype: LocationService
2041
"""
2142
return DATA_SOURCES.get(source.lower())
43+
44+
45+
46+

app/routers/v1.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
"""app.routers.v1.py"""
22
from fastapi import APIRouter
33

4-
from ..services.location.jhu import get_category
4+
from ..services.location.jhu import JHUGateway
55

66
V1 = APIRouter()
7+
gateway = JHUGateway("https://raw.githubusercontent.com/CSSEGISandData/2019-nCoV/master/csse_covid_19_data/csse_covid_19_time_series/")
78

89

910
@V1.get("/all")
1011
async def all_categories():
1112
"""Get all the categories."""
12-
confirmed = await get_category("confirmed")
13-
deaths = await get_category("deaths")
14-
recovered = await get_category("recovered")
13+
confirmed = await gateway.get_category("confirmed")
14+
deaths = await gateway.get_category("deaths")
15+
recovered = await gateway.get_category("recovered")
1516

1617
return {
1718
# Data.
@@ -30,22 +31,22 @@ async def all_categories():
3031
@V1.get("/confirmed")
3132
async def get_confirmed():
3233
"""Confirmed cases."""
33-
confirmed_data = await get_category("confirmed")
34+
confirmed_data = await gateway.get_category("confirmed")
3435

3536
return confirmed_data
3637

3738

3839
@V1.get("/deaths")
3940
async def get_deaths():
4041
"""Total deaths."""
41-
deaths_data = await get_category("deaths")
42+
deaths_data = await gateway.get_category("deaths")
4243

4344
return deaths_data
4445

4546

4647
@V1.get("/recovered")
4748
async def get_recovered():
4849
"""Recovered cases."""
49-
recovered_data = await get_category("recovered")
50+
recovered_data = await gateway.get_category("recovered")
5051

5152
return recovered_data

app/services/location/__init__.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,39 @@
22
from abc import ABC, abstractmethod
33

44

5-
class LocationService(ABC):
5+
class LocationService:
66
"""
77
Service for retrieving locations.
88
"""
9+
"""
10+
Service for retrieving locations from csbs
11+
"""
12+
13+
def __init__(self, gateway: "LocationGateway"):
14+
self.gateway = gateway
915

10-
@abstractmethod
1116
async def get_all(self):
12-
"""
13-
Gets and returns all of the locations.
17+
# Get the locations.
18+
locations = await self.gateway.get_locations()
19+
return locations
1420

15-
:returns: The locations.
16-
:rtype: List[Location]
17-
"""
18-
raise NotImplementedError
21+
async def get(self, loc_id): # pylint: disable=arguments-differ
22+
# Get location at the index equal to the provided id.
23+
locations = await self.get_all()
24+
return locations[loc_id]
25+
26+
def set_gateway(self, gateway: "LocationGateway"):
27+
self.gateway = gateway
28+
29+
30+
class LocationGateway(ABC):
31+
"""
32+
real processing for all kinds of locations
33+
"""
1934

2035
@abstractmethod
21-
async def get(self, id): # pylint: disable=redefined-builtin,invalid-name
36+
async def get_locations(self):
2237
"""
23-
Gets and returns location with the provided id.
24-
25-
:returns: The location.
26-
:rtype: Location
38+
parse all locations from the datasource
2739
"""
2840
raise NotImplementedError

app/services/location/csbs.py

Lines changed: 76 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -10,93 +10,86 @@
1010
from ...coordinates import Coordinates
1111
from ...location.csbs import CSBSLocation
1212
from ...utils import httputils
13-
from . import LocationService
13+
from . import LocationService, LocationGateway
1414

1515
LOGGER = logging.getLogger("services.location.csbs")
1616

1717

18-
class CSBSLocationService(LocationService):
19-
"""
20-
Service for retrieving locations from csbs
21-
"""
2218

23-
async def get_all(self):
24-
# Get the locations.
25-
locations = await get_locations()
26-
return locations
2719

28-
async def get(self, loc_id): # pylint: disable=arguments-differ
29-
# Get location at the index equal to the provided id.
30-
locations = await self.get_all()
31-
return locations[loc_id]
32-
33-
34-
# Base URL for fetching data
35-
BASE_URL = "https://facts.csbs.org/covid-19/covid19_county.csv"
36-
37-
38-
@cached(cache=TTLCache(maxsize=1, ttl=1800))
39-
async def get_locations():
40-
"""
41-
Retrieves county locations; locations are cached for 1 hour
42-
43-
:returns: The locations.
44-
:rtype: dict
45-
"""
46-
data_id = "csbs.locations"
47-
LOGGER.info(f"{data_id} Requesting data...")
48-
# check shared cache
49-
cache_results = await check_cache(data_id)
50-
if cache_results:
51-
LOGGER.info(f"{data_id} using shared cache results")
52-
locations = cache_results
53-
else:
54-
LOGGER.info(f"{data_id} shared cache empty")
55-
async with httputils.CLIENT_SESSION.get(BASE_URL) as response:
56-
text = await response.text()
57-
58-
LOGGER.debug(f"{data_id} Data received")
59-
60-
data = list(csv.DictReader(text.splitlines()))
61-
LOGGER.debug(f"{data_id} CSV parsed")
62-
63-
locations = []
64-
65-
for i, item in enumerate(data):
66-
# General info.
67-
state = item["State Name"]
68-
county = item["County Name"]
69-
70-
# Ensure country is specified.
71-
if county in {"Unassigned", "Unknown"}:
72-
continue
73-
74-
# Date string without "EDT" at end.
75-
last_update = " ".join(item["Last Update"].split(" ")[0:2])
76-
77-
# Append to locations.
78-
locations.append(
79-
CSBSLocation(
80-
# General info.
81-
i,
82-
state,
83-
county,
84-
# Coordinates.
85-
Coordinates(item["Latitude"], item["Longitude"]),
86-
# Last update (parse as ISO).
87-
datetime.strptime(last_update, "%Y-%m-%d %H:%M").isoformat() + "Z",
88-
# Statistics.
89-
int(item["Confirmed"] or 0),
90-
int(item["Death"] or 0),
20+
21+
class CSBSGateway(LocationGateway):
22+
23+
def __init__(self, base_url):
24+
self.BASE_URL = base_url
25+
26+
@cached(cache=TTLCache(maxsize=1, ttl=1800))
27+
async def get_locations(self):
28+
"""
29+
Retrieves county locations; locations are cached for 1 hour
30+
31+
:returns: The locations.
32+
:rtype: dict
33+
"""
34+
data_id = "csbs.locations"
35+
LOGGER.info(f"{data_id} Requesting data...")
36+
# check shared cache
37+
cache_results = await check_cache(data_id)
38+
if cache_results:
39+
LOGGER.info(f"{data_id} using shared cache results")
40+
locations = cache_results
41+
else:
42+
LOGGER.info(f"{data_id} shared cache empty")
43+
async with httputils.CLIENT_SESSION.get(self.BASE_URL) as response:
44+
text = await response.text()
45+
46+
LOGGER.debug(f"{data_id} Data received")
47+
48+
data = list(csv.DictReader(text.splitlines()))
49+
LOGGER.debug(f"{data_id} CSV parsed")
50+
51+
locations = []
52+
53+
for i, item in enumerate(data):
54+
# General info.
55+
state = item["State Name"]
56+
county = item["County Name"]
57+
58+
# Ensure country is specified.
59+
if county in {"Unassigned", "Unknown"}:
60+
continue
61+
62+
# Date string without "EDT" at end.
63+
last_update = " ".join(item["Last Update"].split(" ")[0:2])
64+
65+
# Append to locations.
66+
locations.append(
67+
CSBSLocation(
68+
# General info.
69+
i,
70+
state,
71+
county,
72+
# Coordinates.
73+
Coordinates(item["Latitude"], item["Longitude"]),
74+
# Last update (parse as ISO).
75+
datetime.strptime(last_update, "%Y-%m-%d %H:%M").isoformat() + "Z",
76+
# Statistics.
77+
int(item["Confirmed"] or 0),
78+
int(item["Death"] or 0),
79+
)
9180
)
92-
)
93-
LOGGER.info(f"{data_id} Data normalized")
94-
# save the results to distributed cache
95-
# TODO: fix json serialization
96-
try:
97-
await load_cache(data_id, locations)
98-
except TypeError as type_err:
99-
LOGGER.error(type_err)
100-
101-
# Return the locations.
102-
return locations
81+
LOGGER.info(f"{data_id} Data normalized")
82+
# save the results to distributed cache
83+
# TODO: fix json serialization
84+
try:
85+
await load_cache(data_id, locations)
86+
except TypeError as type_err:
87+
LOGGER.error(type_err)
88+
89+
# Return the locations.
90+
return locations
91+
92+
93+
94+
95+

0 commit comments

Comments
 (0)