Skip to content

Commit

Permalink
Add dataclasses for sign extension
Browse files Browse the repository at this point in the history
  • Loading branch information
dainnilsson committed Oct 25, 2024
1 parent 3703f2c commit 4c0b217
Showing 1 changed file with 63 additions and 36 deletions.
99 changes: 63 additions & 36 deletions fido2/ctap2/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from .. import cbor
from enum import Enum, unique
from dataclasses import dataclass
from typing import Dict, Tuple, Any, Optional, Mapping
from typing import Dict, Tuple, Any, Optional, Mapping, Sequence
import abc
import warnings

Expand Down Expand Up @@ -166,7 +166,7 @@ class _PrfValues(_JsonDataObject):
@dataclass(eq=False, frozen=True)
class _PrfInputs(_JsonDataObject):
eval: Optional[_PrfValues] = None
evalByCredential: Optional[Mapping[str, _PrfValues]] = None
eval_by_credential: Optional[Mapping[str, _PrfValues]] = None


@dataclass(eq=False, frozen=True)
Expand Down Expand Up @@ -209,11 +209,10 @@ def process_get_input(self, inputs):
if not self.is_supported():
return

data = inputs.get("prf")
if data:
prf = _PrfInputs.from_dict(data)
prf = _PrfInputs.from_dict(inputs.get("prf"))
if prf:
secrets = prf.eval
by_creds = prf.evalByCredential
by_creds = prf.eval_by_credential
if by_creds:
# Make sure all keys are valid IDs from allow_credentials
allow_list = self._get_options.allow_credentials
Expand All @@ -236,11 +235,10 @@ def process_get_input(self, inputs):
)
self.prf = True
else:
data = inputs.get("hmacGetSecret")
if not data or not self._allow_hmac_secret:
get_secret = _HmacGetSecretInput.from_dict(inputs.get("hmacGetSecret"))
if not get_secret or not self._allow_hmac_secret:
return
res = _HmacGetSecretInput.from_dict(data)
salts = res.salt1, res.salt2 or b""
salts = get_secret.salt1, get_secret.salt2 or b""
self.prf = False

if not (
Expand Down Expand Up @@ -302,7 +300,7 @@ def is_supported(self):
return super().is_supported() and self.ctap.info.options.get("largeBlobs")

def process_create_input(self, inputs):
data = _LargeBlobInputs.from_dict(inputs.get("largeBlob", {}))
data = _LargeBlobInputs.from_dict(inputs.get("largeBlob"))
if data:
if data.read or data.write:
raise ValueError("Invalid set of parameters")
Expand All @@ -318,12 +316,13 @@ def process_create_output(self, attestation_response, *args):
}

def get_get_permissions(self, inputs):
if _LargeBlobInputs.from_dict(inputs.get("largeBlob", {})).write:
data = _LargeBlobInputs.from_dict(inputs.get("largeBlob"))
if data and data.write:
return ClientPin.PERMISSION.LARGE_BLOB_WRITE
return ClientPin.PERMISSION(0)

def process_get_input(self, inputs):
data = _LargeBlobInputs.from_dict(inputs.get("largeBlob", {}))
data = _LargeBlobInputs.from_dict(inputs.get("largeBlob"))
if data:
if data.support or (data.read and data.write):
raise ValueError("Invalid set of parameters")
Expand All @@ -333,7 +332,7 @@ def process_get_input(self, inputs):
self._action = True
else:
self._action = data.write
return True if data else None
return True

def process_get_output(self, assertion_response, token, pin_protocol):
blob_key = assertion_response.large_blob_key
Expand Down Expand Up @@ -438,6 +437,36 @@ def process_create_output(self, attestation_response, *args):
return {"credProps": _CredPropsOutputs(rk=rk)}


@dataclass(eq=False, frozen=True)
class _SignGenerateKeyInputs(_JsonDataObject):
algorithms: Sequence[int]
ph_data: Optional[bytes] = None


@dataclass(eq=False, frozen=True)
class _SignSignInputs(_JsonDataObject):
ph_data: bytes
key_handle_by_credential: Mapping[str, bytes]


@dataclass(eq=False, frozen=True)
class _SignInputs(_JsonDataObject):
generate_key: Optional[_SignGenerateKeyInputs] = None
sign: Optional[_SignSignInputs] = None


@dataclass(eq=False, frozen=True)
class _SignGeneratedKey(_JsonDataObject):
public_key: bytes
key_handle: bytes


@dataclass(eq=False, frozen=True)
class _SignOutputs(_JsonDataObject):
generated_key: Optional[_SignGeneratedKey] = None
signature: Optional[bytes] = None


class SignExtension(Ctap2Extension):
"""
Implements the sign CTAP2 extension.
Expand All @@ -446,14 +475,14 @@ class SignExtension(Ctap2Extension):
NAME = "sign"

def process_create_input(self, inputs):
data = inputs.get("sign", {})
data = _SignInputs.from_dict(inputs.get("sign"))
if not data or not self.is_supported():
return

if "sign" in data or "generateKey" not in data:
if data.sign or not data.generate_key:
raise ValueError("Invalid inputs")

gk = data["generateKey"]
gk = data.generate_key

selection = (
self._create_options.authenticator_selection
Expand All @@ -464,10 +493,10 @@ def process_create_input(self, inputs):
if selection.user_verification == UserVerificationRequirement.REQUIRED
else 0b001
)
outputs = {3: gk["algorithms"], 4: flags}
outputs = {3: gk.algorithms, 4: flags}

if "phData" in gk:
outputs[0] = gk["phData"]
if gk.pd_data:
outputs[0] = gk.ph_data

return outputs

Expand All @@ -478,28 +507,26 @@ def process_create_output(self, attestation_response, *args):
assert cred_data is not None # nosec
pk = cred_data.public_key

output = {
"generatedKey": {
"publicKey": cbor.encode(pk),
"keyHandle": cbor.encode(pk.get_ref()),
}
return {
"sign": _SignOutputs(
generated_key=_SignGeneratedKey(
public_key=cbor.encode(pk),
key_handle=cbor.encode(pk.get_ref()),
),
signature=data.get(6),
)
}

if 6 in data:
output["signature"] = data[6]

return {"sign": output}

def process_get_input(self, inputs):
data = inputs.get("sign", {})
data = _SignInputs.from_dict(inputs.get("sign"))
if not data or not self.is_supported():
return

if "sign" not in data or "generateKey" in data:
if not data.sign or data.generate_key:
raise ValueError("Invalid inputs")

sign = data["sign"]
by_creds = sign["keyHandleByCredential"]
sign = data.sign
by_creds = sign.key_handle_by_credential

# Make sure all keys are valid IDs from allow_credentials
allow_list = self._get_options.allow_credentials
Expand All @@ -513,11 +540,11 @@ def process_get_input(self, inputs):
kh = by_creds[websafe_encode(self._selected.id)]

return {
0: sign["phData"],
0: sign.phData,
5: [kh],
}

def process_get_output(self, assertion_response, *args):
data = assertion_response.auth_data.extensions.get(self.NAME)

return {"sign": {"signature": data[6]}}
return {"sign": _SignOutputs(signature=data[6])}

0 comments on commit 4c0b217

Please sign in to comment.