Skip to content

Commit 0e11daa

Browse files
committed
Update test_jhu to handle asyncio
- Move test fixtures into tests/fixtures.py - Update test fixtures to mock aiohttp.ClientSession.get, instead of requests.get - Add a context manager to replace the global httputils.client_session with an AsyncMock - Misc. cleanup
1 parent 63675dc commit 0e11daa

File tree

2 files changed

+103
-69
lines changed

2 files changed

+103
-69
lines changed

tests/fixtures.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import datetime
2+
import os
3+
from contextlib import asynccontextmanager
4+
from unittest import mock
5+
6+
from app.utils import httputils
7+
8+
9+
class DateTimeStrpTime:
10+
"""Returns instance of `DateTimeStrpTime`
11+
when calling `app.services.location.jhu.datetime.trptime(date, '%m/%d/%y').isoformat()`.
12+
"""
13+
14+
def __init__(self, date, strformat):
15+
self.date = date
16+
self.strformat = strformat
17+
18+
def isoformat(self):
19+
return datetime.datetime.strptime(self.date, self.strformat).isoformat()
20+
21+
22+
class FakeRequestsGetResponse:
23+
"""Fake instance of a response from `aiohttp.ClientSession.get`.
24+
"""
25+
26+
def __init__(self, url, filename, state):
27+
self.url = url
28+
self.filename = filename
29+
self.state = state
30+
31+
async def text(self):
32+
return self.read_file(self.state)
33+
34+
def read_file(self, state):
35+
"""
36+
Mock HTTP GET-method and return text from file
37+
"""
38+
state = state.lower()
39+
40+
# Determine filepath.
41+
filepath = os.path.join(os.path.dirname(__file__), "example_data/{}.csv".format(state))
42+
43+
# Return fake response.
44+
print("Try to read {}".format(filepath))
45+
with open(filepath, "r") as file:
46+
return file.read()
47+
48+
49+
@asynccontextmanager
50+
async def mock_client_session():
51+
"""Context manager that replaces the global client_session with an AsyncMock instance.
52+
53+
:Example:
54+
55+
>>> async with mock_client_session() as mocked_client_session:
56+
>>> mocked_client_session.get = mocked_session_get
57+
>>> # test code...
58+
59+
"""
60+
61+
httputils.client_session = mocked_client_session = mock.AsyncMock()
62+
try:
63+
yield mocked_client_session
64+
finally:
65+
del httputils.client_session
66+
67+
68+
@asynccontextmanager
69+
async def mocked_session_get(*args, **kwargs):
70+
"""Mock response from client_session.get.
71+
"""
72+
73+
url = args[0]
74+
filename = url.split("/")[-1]
75+
76+
# clean up for id token (e.g. Deaths)
77+
state = filename.split("-")[-1].replace(".csv", "").lower().capitalize()
78+
79+
yield FakeRequestsGetResponse(url, filename, state)
80+
81+
82+
def mocked_strptime_isoformat(*args, **kwargs):
83+
"""Mock return value from datetime.strptime().isoformat().
84+
"""
85+
86+
date = args[0]
87+
strformat = args[1]
88+
89+
return DateTimeStrpTime(date, strformat)

tests/test_jhu.py

Lines changed: 14 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -6,81 +6,26 @@
66
import app
77
from app import location
88
from app.services.location import jhu
9-
from app.utils import date
9+
from tests.fixtures import mock_client_session
10+
from tests.fixtures import mocked_session_get
11+
from tests.fixtures import mocked_strptime_isoformat
1012

1113
DATETIME_STRING = "2020-03-17T10:23:22.505550"
1214

1315

14-
def mocked_requests_get(*args, **kwargs):
15-
class FakeRequestsGetResponse:
16-
"""
17-
Returns instance of `FakeRequestsGetResponse`
18-
when calling `app.services.location.jhu.requests.get()`
19-
"""
20-
21-
def __init__(self, url, filename, state):
22-
self.url = url
23-
self.filename = filename
24-
self.state = state
25-
self.text = self.read_file(self.state)
26-
27-
def read_file(self, state):
28-
"""
29-
Mock HTTP GET-method and return text from file
30-
"""
31-
state = state.lower()
32-
33-
# Determine filepath.
34-
filepath = "tests/example_data/{}.csv".format(state)
35-
36-
# Return fake response.
37-
print("Try to read {}".format(filepath))
38-
with open(filepath, "r") as file:
39-
return file.read()
40-
41-
# get url from `request.get`
42-
url = args[0]
43-
44-
# get filename from url
45-
filename = url.split("/")[-1]
46-
47-
# clean up for id token (e.g. Deaths)
48-
state = filename.split("-")[-1].replace(".csv", "").lower().capitalize()
49-
50-
return FakeRequestsGetResponse(url, filename, state)
51-
52-
53-
def mocked_strptime_isoformat(*args, **kwargs):
54-
class DateTimeStrpTime:
55-
"""
56-
Returns instance of `DateTimeStrpTime`
57-
when calling `app.services.location.jhu.datetime.trptime(date, '%m/%d/%y').isoformat()`
58-
"""
59-
60-
def __init__(self, date, strformat):
61-
self.date = date
62-
self.strformat = strformat
63-
64-
def isoformat(self):
65-
return datetime.datetime.strptime(self.date, self.strformat).isoformat()
66-
67-
date = args[0]
68-
strformat = args[1]
69-
70-
return DateTimeStrpTime(date, strformat)
71-
72-
16+
@pytest.mark.asyncio
7317
@mock.patch("app.services.location.jhu.datetime")
74-
@mock.patch("app.services.location.jhu.requests.get", side_effect=mocked_requests_get)
75-
def test_get_locations(mock_request_get, mock_datetime):
76-
# mock app.services.location.jhu.datetime.utcnow().isoformat()
18+
async def test_get_locations(mock_datetime):
7719
mock_datetime.utcnow.return_value.isoformat.return_value = DATETIME_STRING
7820
mock_datetime.strptime.side_effect = mocked_strptime_isoformat
7921

80-
output = jhu.get_locations()
81-
assert isinstance(output, list)
82-
assert isinstance(output[0], location.Location)
22+
async with mock_client_session() as mocked_client_session:
23+
mocked_client_session.get = mocked_session_get
24+
output = await jhu.get_locations()
25+
26+
assert isinstance(output, list)
27+
assert isinstance(output[0], location.Location)
8328

84-
# `jhu.get_locations()` creates id based on confirmed list
85-
location_confirmed = jhu.get_category("confirmed")
86-
assert len(output) == len(location_confirmed["locations"])
29+
# `jhu.get_locations()` creates id based on confirmed list
30+
location_confirmed = await jhu.get_category("confirmed")
31+
assert len(output) == len(location_confirmed["locations"])

0 commit comments

Comments
 (0)