Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
feat: Ack JWT claims for authentication #94
  • Loading branch information
EliuX committed Apr 22, 2020
commit 1a41ed7077c121c595ab29f2b6ff2fcdf8a3618f
7 changes: 5 additions & 2 deletions requirements/time_tracker_api/prod.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@ Flask-Script==2.0.6
#Semantic versioning
python-semantic-release==5.2.0

# The Debug Toolbar
#The Debug Toolbar
Flask-DebugToolbar==0.11.0

#CORS
flask-cors==3.0.8
flask-cors==3.0.8

#JWT
PyJWT==1.7.1
52 changes: 51 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,50 @@
from datetime import datetime, timedelta

import jwt
import pytest
from faker import Faker
from flask import Flask
from flask import Flask, url_for
from flask.testing import FlaskClient

from commons.data_access_layer.cosmos_db import CosmosDBRepository, datetime_str, current_datetime
from time_tracker_api import create_app
from time_tracker_api.security import get_or_generate_dev_secret_key
from time_tracker_api.time_entries.time_entries_model import TimeEntryCosmosDBRepository

fake = Faker()
Faker.seed()

TEST_USER = {
"name": "[email protected]",
"password": "secret"
}


class User:
def __init__(self, username, password):
self.username = username
self.password = password


class AuthActions:
"""Auth actions container in tests"""

def __init__(self, app, client):
self._app = app
self._client = client

# def login(self, username=TEST_USER["name"],
# password=TEST_USER["password"]):
# login_url = url_for("security.login", self._app)
# return open_with_basic_auth(self._client,
# login_url,
# username,
# password)
#
# def logout(self):
# return self._client.get(url_for("security.logout", self._app),
# follow_redirects=True)


@pytest.fixture(scope='session')
def app() -> Flask:
Expand Down Expand Up @@ -148,3 +183,18 @@ def running_time_entry(time_entry_repository: TimeEntryCosmosDBRepository,

time_entry_repository.delete(id=created_time_entry.id,
partition_key_value=tenant_id)


@pytest.fixture(scope="session")
def valid_jwt(app: Flask) -> str:
expiration_time = datetime.utcnow() + timedelta(seconds=3600)
return jwt.encode({
"iss": "https://securityioet.b2clogin.com/%s/v2.0/" % fake.uuid4(),
"oid": fake.uuid4(),
'exp': expiration_time
}, key=get_or_generate_dev_secret_key()).decode("UTF-8")


@pytest.fixture(scope="session")
def valid_header(valid_jwt: str) -> dict:
return {'Authorization': "Bearer %s" % valid_jwt}
25 changes: 19 additions & 6 deletions tests/time_tracker_api/activities/activities_namespace_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,18 @@
}).update(valid_activity_data)


def test_create_activity_should_succeed_with_valid_request(client: FlaskClient, mocker: MockFixture):
def test_create_activity_should_succeed_with_valid_request(client: FlaskClient,
mocker: MockFixture,
valid_header: dict):
from time_tracker_api.activities.activities_namespace import activity_dao
repository_create_mock = mocker.patch.object(activity_dao.repository,
'create',
return_value=fake_activity)

response = client.post("/activities", json=valid_activity_data, follow_redirects=True)
response = client.post("/activities",
headers=valid_header,
json=valid_activity_data,
follow_redirects=True)

assert HTTPStatus.CREATED == response.status_code
repository_create_mock.assert_called_once()
Expand Down Expand Up @@ -57,7 +62,9 @@ def test_list_all_activities(client: FlaskClient, mocker: MockFixture):
repository_find_all_mock.assert_called_once()


def test_get_activity_should_succeed_with_valid_id(client: FlaskClient, mocker: MockFixture):
def test_get_activity_should_succeed_with_valid_id(client: FlaskClient,
mocker: MockFixture,
valid_header: dict):
from time_tracker_api.activities.activities_namespace import activity_dao

valid_id = fake.random_int(1, 9999)
Expand All @@ -66,15 +73,19 @@ def test_get_activity_should_succeed_with_valid_id(client: FlaskClient, mocker:
'find',
return_value=fake_activity)

response = client.get("/activities/%s" % valid_id, follow_redirects=True)
response = client.get("/activities/%s" % valid_id,
headers=valid_header,
follow_redirects=True)

assert HTTPStatus.OK == response.status_code
fake_activity == json.loads(response.data)
repository_find_mock.assert_called_once_with(str(valid_id),
partition_key_value=current_user_tenant_id())


def test_get_activity_should_return_not_found_with_invalid_id(client: FlaskClient, mocker: MockFixture):
def test_get_activity_should_return_not_found_with_invalid_id(client: FlaskClient,
mocker: MockFixture,
valid_header: dict):
from time_tracker_api.activities.activities_namespace import activity_dao
from werkzeug.exceptions import NotFound

Expand All @@ -84,7 +95,9 @@ def test_get_activity_should_return_not_found_with_invalid_id(client: FlaskClien
'find',
side_effect=NotFound)

response = client.get("/activities/%s" % invalid_id, follow_redirects=True)
response = client.get("/activities/%s" % invalid_id,
headers=valid_header,
follow_redirects=True)

assert HTTPStatus.NOT_FOUND == response.status_code
repository_find_mock.assert_called_once_with(str(invalid_id),
Expand Down
34 changes: 34 additions & 0 deletions tests/time_tracker_api/security_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from time_tracker_api.security import parse_jwt, parse_tenant_id_from_iss_claim


def test_parse_jwt_with_valid_input(valid_jwt: str):
result = parse_jwt("Bearer %s" % valid_jwt)

assert result is not None
assert type(result) is dict


def test_parse_jwt_with_invalid_input():
result = parse_jwt("whetever")

assert result is None


def test_parse_tenant_id_from_iss_claim_with_valid_input():
valid_iss_claim = "https://securityioet.b2clogin.com/b21c4e98-c4bf-420f-9d76-e51c2515c7a4/v2.0/"

result = parse_tenant_id_from_iss_claim(valid_iss_claim)

assert result is not None
assert type(result) is str
assert result == "b21c4e98-c4bf-420f-9d76-e51c2515c7a4"


def test_parse_tenant_id_from_iss_claim_with_invalid_input():
invalid_iss_claim1 = "https://securityioet.b2clogin.com/whatever/v2.0/"
invalid_iss_claim2 = ""

result1 = parse_tenant_id_from_iss_claim(invalid_iss_claim1)
result2 = parse_tenant_id_from_iss_claim(invalid_iss_claim2)

assert result1 == result2 == None
5 changes: 4 additions & 1 deletion time_tracker_api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@
from flask_restplus._http import HTTPStatus

from commons.data_access_layer.cosmos_db import CustomError
from time_tracker_api import security
from time_tracker_api.version import __version__

faker = Faker()

api = Api(
version=__version__,
title="TimeTracker API",
description="API for the TimeTracker project"
description="API for the TimeTracker project",
authorizations=security.authorizations,
security="TimeTracker JWT",
)

# For matching UUIDs
Expand Down
4 changes: 2 additions & 2 deletions time_tracker_api/config.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import os

from time_tracker_api.security import generate_dev_secret_key
from time_tracker_api.security import get_or_generate_dev_secret_key

DISABLE_STR_VALUES = ("false", "0", "disabled")


class Config:
SECRET_KEY = generate_dev_secret_key()
SECRET_KEY = get_or_generate_dev_secret_key()
SQL_DATABASE_URI = os.environ.get('SQL_DATABASE_URI')
PROPAGATE_EXCEPTIONS = True
RESTPLUS_VALIDATE = True
Expand Down
97 changes: 79 additions & 18 deletions time_tracker_api/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,97 @@
This is where we handle everything regarding to authorization
and authentication. Also stores helper functions related to it.
"""
import re

import jwt
from faker import Faker
from flask import request
from flask_restplus import abort
from flask_restplus._http import HTTPStatus
from jwt import DecodeError, ExpiredSignatureError

fake = Faker()

dev_secret_key: str = None

authorizations = {
"TimeTracker JWT": {
'type': 'apiKey',
'in': 'header',
'name': 'Authorization',
'description': "Specify in the value **'Bearer <JWT>'**, where JWT is the token",
}
}

iss_claim_pattern = re.compile(
r"securityioet.b2clogin.com/(?P<tenant_id>[0-9a-f]{8}\-[0-9a-f]{4}\-4[0-9a-f]{3}\-[89ab][0-9a-f]{3}\-[0-9a-f]{12})")


def current_user_id() -> str:
"""
Returns the id of the authenticated user in
Azure Active Directory
"""
return 'anonymous'
oid_claim = get_token_json().get("oid")
if oid_claim is None:
abort(message='The claim "oid" is missing in the JWT', code=HTTPStatus.UNAUTHORIZED)

return oid_claim


def current_user_tenant_id() -> str:
# TODO Get this from the JWT
return "ioet"
iss_claim = get_token_json().get("iss")
if iss_claim is None:
abort(message='The claim "iss" is missing in the JWT', code=HTTPStatus.UNAUTHORIZED)

tenant_id = parse_tenant_id_from_iss_claim(iss_claim)
if tenant_id is None:
abort(message='The format of the claim "iss" cannot be understood. '
'Please contact the development team.',
code=HTTPStatus.UNAUTHORIZED)

return tenant_id

def generate_dev_secret_key():
from time_tracker_api import flask_app as app
"""
Generates a security key for development purposes
:return: str
"""

def get_or_generate_dev_secret_key():
global dev_secret_key
dev_secret_key = fake.password(length=16, special_chars=True, digits=True, upper_case=True, lower_case=True)
if app.config.get("FLASK_DEBUG", False): # pragma: no cover
print('*********************************************************')
print("The generated secret is \"%s\"" % dev_secret_key)
print('*********************************************************')
if dev_secret_key is None:
from time_tracker_api import flask_app as app
"""
Generates a security key for development purposes
:return: str
"""
dev_secret_key = fake.password(length=16, special_chars=True, digits=True, upper_case=True, lower_case=True)
if app.config.get("FLASK_DEBUG", False): # pragma: no cover
print('*********************************************************')
print("The generated secret is \"%s\"" % dev_secret_key)
print('*********************************************************')
return dev_secret_key


def parse_jwt(authentication_header_content):
if authentication_header_content is not None:
parsed_content = authentication_header_content.split("Bearer ")

if len(parsed_content) > 1:
return jwt.decode(parsed_content[1], verify=False)

return None


def get_authorization_jwt():
auth_header = request.headers.get('Authorization')
return parse_jwt(auth_header)


def get_token_json():
try:
return get_authorization_jwt()
except DecodeError:
abort(message='Malformed token', code=HTTPStatus.UNAUTHORIZED)
except ExpiredSignatureError:
abort(message='Expired token', code=HTTPStatus.UNAUTHORIZED)


def parse_tenant_id_from_iss_claim(iss_claim: str) -> str:
m = iss_claim_pattern.search(iss_claim)
if m is not None:
return m.group('tenant_id')

return None