Skip to content

Commit

Permalink
Merge pull request #309 from AzureAD/release-1.9.0
Browse files Browse the repository at this point in the history
Release 1.9.0
  • Loading branch information
rayluo authored Feb 9, 2021
2 parents 82f9f0c + 2616d89 commit 72a7250
Show file tree
Hide file tree
Showing 9 changed files with 287 additions and 137 deletions.
89 changes: 89 additions & 0 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: CI/CD

on:
push:
pull_request:
branches: [ dev ]

jobs:
ci:
env:
# Fake a TRAVIS env so that the pre-existing test cases would behave like before
TRAVIS: true
LAB_APP_CLIENT_ID: ${{ secrets.LAB_APP_CLIENT_ID }}
LAB_APP_CLIENT_SECRET: ${{ secrets.LAB_APP_CLIENT_SECRET }}
LAB_OBO_CLIENT_SECRET: ${{ secrets.LAB_OBO_CLIENT_SECRET }}
LAB_OBO_CONFIDENTIAL_CLIENT_ID: ${{ secrets.LAB_OBO_CONFIDENTIAL_CLIENT_ID }}
LAB_OBO_PUBLIC_CLIENT_ID: ${{ secrets.LAB_OBO_PUBLIC_CLIENT_ID }}

# Derived from https://docs.github.com/en/actions/guides/building-and-testing-python#starting-with-the-python-workflow-template
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [2.7, 3.5, 3.6, 3.7, 3.8, 3.9]

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}

# Derived from https://github.com/actions/cache/blob/main/examples.md#using-pip-to-get-cache-location
# However, a before-and-after test shows no improvement in this repo,
# possibly because the bottlenect was not in downloading those small python deps.
- name: Get pip cache dir from pip 20.1+
id: pip-cache
run: |
echo "::set-output name=dir::$(pip cache dir)"
- name: pip cache
uses: actions/cache@v2
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt') }}

- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
#flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
#flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest
cd:
needs: ci
if: github.event_name == 'push' && (startsWith(github.ref, 'refs/tags') || github.ref == 'refs/heads/main')
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.9
uses: actions/setup-python@v2
with:
python-version: 3.9
- name: Build a package for release
run: |
python -m pip install build --user
python -m build --sdist --wheel --outdir dist/ .
- name: Publish to TestPyPI
uses: pypa/[email protected]
if: github.ref == 'refs/heads/main'
with:
user: __token__
password: ${{ secrets.TEST_PYPI_API_TOKEN }}
repository_url: https://test.pypi.org/legacy/
- name: Publish to PyPI
if: startsWith(github.ref, 'refs/tags')
uses: pypa/[email protected]
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}
69 changes: 40 additions & 29 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


# The __init__.py will import this. Not the other way around.
__version__ = "1.8.0"
__version__ = "1.9.0"

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -56,7 +56,9 @@ def decorate_scope(
CLIENT_CURRENT_TELEMETRY = 'x-client-current-telemetry'

def _get_new_correlation_id():
return str(uuid.uuid4())
correlation_id = str(uuid.uuid4())
logger.debug("Generates correlation_id: %s", correlation_id)
return correlation_id


def _build_current_telemetry_request_header(public_api_id, force_refresh=False):
Expand Down Expand Up @@ -439,16 +441,20 @@ def get_authorization_request_url(
{"authorization_endpoint": the_authority.authorization_endpoint},
self.client_id,
http_client=self.http_client)
return client.build_auth_request_uri(
response_type=response_type,
redirect_uri=redirect_uri, state=state, login_hint=login_hint,
prompt=prompt,
scope=decorate_scope(scopes, self.client_id),
nonce=nonce,
domain_hint=domain_hint,
claims=_merge_claims_challenge_and_capabilities(
self._client_capabilities, claims_challenge),
)
warnings.warn(
"Change your get_authorization_request_url() "
"to initiate_auth_code_flow()", DeprecationWarning)
with warnings.catch_warnings(record=True):
return client.build_auth_request_uri(
response_type=response_type,
redirect_uri=redirect_uri, state=state, login_hint=login_hint,
prompt=prompt,
scope=decorate_scope(scopes, self.client_id),
nonce=nonce,
domain_hint=domain_hint,
claims=_merge_claims_challenge_and_capabilities(
self._client_capabilities, claims_challenge),
)

def acquire_token_by_auth_code_flow(
self, auth_code_flow, auth_response, scopes=None, **kwargs):
Expand Down Expand Up @@ -570,20 +576,24 @@ def acquire_token_by_authorization_code(
# really empty.
assert isinstance(scopes, list), "Invalid parameter type"
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
return self.client.obtain_token_by_authorization_code(
code, redirect_uri=redirect_uri,
scope=decorate_scope(scopes, self.client_id),
headers={
CLIENT_REQUEST_ID: _get_new_correlation_id(),
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_BY_AUTHORIZATION_CODE_ID),
},
data=dict(
kwargs.pop("data", {}),
claims=_merge_claims_challenge_and_capabilities(
self._client_capabilities, claims_challenge)),
nonce=nonce,
**kwargs)
warnings.warn(
"Change your acquire_token_by_authorization_code() "
"to acquire_token_by_auth_code_flow()", DeprecationWarning)
with warnings.catch_warnings(record=True):
return self.client.obtain_token_by_authorization_code(
code, redirect_uri=redirect_uri,
scope=decorate_scope(scopes, self.client_id),
headers={
CLIENT_REQUEST_ID: _get_new_correlation_id(),
CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header(
self.ACQUIRE_TOKEN_BY_AUTHORIZATION_CODE_ID),
},
data=dict(
kwargs.pop("data", {}),
claims=_merge_claims_challenge_and_capabilities(
self._client_capabilities, claims_challenge)),
nonce=nonce,
**kwargs)

def get_accounts(self, username=None):
"""Get a list of accounts which previously signed in, i.e. exists in cache.
Expand Down Expand Up @@ -942,7 +952,7 @@ def _validate_ssh_cert_input_data(self, data):
"you must include a string parameter named 'key_id' "
"which identifies the key in the 'req_cnf' argument.")

def acquire_token_by_refresh_token(self, refresh_token, scopes):
def acquire_token_by_refresh_token(self, refresh_token, scopes, **kwargs):
"""Acquire token(s) based on a refresh token (RT) obtained from elsewhere.
You use this method only when you have old RTs from elsewhere,
Expand All @@ -965,6 +975,7 @@ def acquire_token_by_refresh_token(self, refresh_token, scopes):
* A dict contains "error" and some other keys, when error happened.
* A dict contains no "error" key means migration was successful.
"""
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
return self.client.obtain_token_by_refresh_token(
refresh_token,
scope=decorate_scope(scopes, self.client_id),
Expand All @@ -976,7 +987,7 @@ def acquire_token_by_refresh_token(self, refresh_token, scopes):
rt_getter=lambda rt: rt,
on_updating_rt=False,
on_removing_rt=lambda rt_item: None, # No OP
)
**kwargs)


class PublicClientApplication(ClientApplication): # browser app or mobile app
Expand Down Expand Up @@ -1233,6 +1244,7 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
- an error response would contain "error" and usually "error_description".
"""
# TBD: force_refresh behavior
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
return self.client.obtain_token_for_client(
scope=scopes, # This grant flow requires no scope decoration
headers={
Expand Down Expand Up @@ -1294,4 +1306,3 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No
self.ACQUIRE_TOKEN_ON_BEHALF_OF_ID),
},
**kwargs)

12 changes: 11 additions & 1 deletion msal/oauth2cli/assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,15 @@

logger = logging.getLogger(__name__)


def _str2bytes(raw):
# A conversion based on duck-typing rather than six.text_type
try: # Assuming it is a string
return raw.encode(encoding="utf-8")
except: # Otherwise we treat it as bytes and return it as-is
return raw


class AssertionCreator(object):
def create_normal_assertion(
self, audience, issuer, subject, expires_at=None, expires_in=600,
Expand Down Expand Up @@ -103,8 +112,9 @@ def create_normal_assertion(
payload['nbf'] = not_before
payload.update(additional_claims or {})
try:
return jwt.encode(
str_or_bytes = jwt.encode( # PyJWT 1 returns bytes, PyJWT 2 returns str
payload, self.key, algorithm=self.algorithm, headers=self.headers)
return _str2bytes(str_or_bytes) # We normalize them into bytes
except:
if self.algorithm.startswith("RS") or self.algorithm.starswith("ES"):
logger.exception(
Expand Down
53 changes: 36 additions & 17 deletions msal/oauth2cli/authcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,9 @@ def obtain_auth_code(listen_port, auth_uri=None): # Historically only used in t
).get("code")


def _browse(auth_uri):
def _browse(auth_uri): # throws ImportError, possibly webbrowser.Error in future
import webbrowser # Lazy import. Some distro may not have this.
controller = webbrowser.get() # Get a default controller
# Some Linux Distro does not setup default browser properly,
# so we try to explicitly use some popular browser, if we found any.
for browser in ["chrome", "firefox", "safari", "windows-default"]:
try:
controller = webbrowser.get(browser)
break
except webbrowser.Error:
pass # This browser is not installed. Try next one.
logger.info("Please open a browser on THIS device to visit: %s" % auth_uri)
controller.open(auth_uri)
return webbrowser.open(auth_uri) # Use default browser. Customizable by $BROWSER


def _qs2kv(qs):
Expand Down Expand Up @@ -130,14 +120,16 @@ def get_port(self):
return self._server.server_address[1]

def get_auth_response(self, auth_uri=None, timeout=None, state=None,
welcome_template=None, success_template=None, error_template=None):
"""Wait and return the auth response, or None when timeout.
welcome_template=None, success_template=None, error_template=None,
auth_uri_callback=None,
):
"""Wait and return the auth response. Raise RuntimeError when timeout.
:param str auth_uri:
If provided, this function will try to open a local browser.
:param int timeout: In seconds. None means wait indefinitely.
:param str state:
You may provide the state you used in auth_url,
You may provide the state you used in auth_uri,
then we will use it to validate incoming response.
:param str welcome_template:
If provided, your end user will see it instead of the auth_uri.
Expand All @@ -152,6 +144,10 @@ def get_auth_response(self, auth_uri=None, timeout=None, state=None,
The page will be displayed when authentication encountered error.
Placeholders can be any of these:
https://tools.ietf.org/html/rfc6749#section-5.2
:param callable auth_uri_callback:
A function with the shape of lambda auth_uri: ...
When a browser was unable to be launch, this function will be called,
so that the app could tell user to manually visit the auth_uri.
:return:
The auth response of the first leg of Auth Code flow,
typically {"code": "...", "state": "..."} or {"error": "...", ...}
Expand All @@ -164,8 +160,31 @@ def get_auth_response(self, auth_uri=None, timeout=None, state=None,
logger.debug("Abort by visit %s", abort_uri)
self._server.welcome_page = Template(welcome_template or "").safe_substitute(
auth_uri=auth_uri, abort_uri=abort_uri)
if auth_uri:
_browse(welcome_uri if welcome_template else auth_uri)
if auth_uri: # Now attempt to open a local browser to visit it
_uri = welcome_uri if welcome_template else auth_uri
logger.info("Open a browser on this device to visit: %s" % _uri)
browser_opened = False
try:
browser_opened = _browse(_uri)
except: # Had to use broad except, because the potential
# webbrowser.Error is purposely undefined outside of _browse().
# Absorb and proceed. Because browser could be manually run elsewhere.
logger.exception("_browse(...) unsuccessful")
if not browser_opened:
if not auth_uri_callback:
logger.warning(
"Found no browser in current environment. "
"If this program is being run inside a container "
"which has access to host network "
"(i.e. started by `docker run --net=host -it ...`), "
"you can use browser on host to visit the following link. "
"Otherwise, this auth attempt would either timeout "
"(current timeout setting is {timeout}) "
"or be aborted by CTRL+C. Auth URI: {auth_uri}".format(
auth_uri=_uri, timeout=timeout))
else: # Then it is the auth_uri_callback()'s job to inform the user
auth_uri_callback(_uri)

self._server.success_template = Template(success_template or
"Authentication completed. You can close this window now.")
self._server.error_template = Template(error_template or
Expand Down
10 changes: 7 additions & 3 deletions msal/oauth2cli/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ def __init__(
client_secret (str): Triggers HTTP AUTH for Confidential Client
client_assertion (bytes, callable):
The client assertion to authenticate this client, per RFC 7521.
It can be a raw SAML2 assertion (this method will encode it for you),
or a raw JWT assertion.
It can be a raw SAML2 assertion (we will base64 encode it for you),
or a raw JWT assertion in bytes (which we will relay to http layer).
It can also be a callable (recommended),
so that we will do lazy creation of an assertion.
client_assertion_type (str):
Expand Down Expand Up @@ -198,7 +198,9 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749
self.default_body["client_assertion_type"], lambda a: a)
_data["client_assertion"] = encoder(
self.client_assertion() # Do lazy on-the-fly computation
if callable(self.client_assertion) else self.client_assertion)
if callable(self.client_assertion) else self.client_assertion
) # The type is bytes, which is preferrable. See also:
# https://github.com/psf/requests/issues/4503#issuecomment-455001070

_data.update(self.default_body) # It may contain authen parameters
_data.update(data or {}) # So the content in data param prevails
Expand Down Expand Up @@ -578,6 +580,7 @@ def obtain_token_by_browser(
welcome_template=None,
success_template=None,
auth_params=None,
auth_uri_callback=None,
**kwargs):
"""A native app can use this method to obtain token via a local browser.
Expand Down Expand Up @@ -635,6 +638,7 @@ def obtain_token_by_browser(
timeout=timeout,
welcome_template=welcome_template,
success_template=success_template,
auth_uri_callback=auth_uri_callback,
)
except PermissionError:
if 0 < listen_port < 1024:
Expand Down
Loading

0 comments on commit 72a7250

Please sign in to comment.