Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: decorator to validate pt.Model type hints #38

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/patito/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from patito.exceptions import ValidationError
from patito.polars import DataFrame, LazyFrame
from patito.pydantic import Field, Model
from patito.decorators import validate_hints

_CACHING_AVAILABLE = False
_DUCKDB_AVAILABLE = False
Expand All @@ -23,6 +24,7 @@
"exceptions",
"field",
"sql",
"validate_hints"
]

try:
Expand Down
38 changes: 38 additions & 0 deletions src/patito/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import inspect
import typing
from functools import wraps
from typing import Any, Callable, TypeVar

import patito as pt


T = TypeVar("T")

def validate_hints(wrapped: Callable[..., T]) -> Callable[..., T]:
"""Validate function arguments and return on pt.Model type hints

:param wrapped: the function to decorate
"""
def _validate_or_skip(validator: Any, target: Any) -> None:
if not isinstance(validator, pt.pydantic.ModelMetaclass):
return

validator.validate(target)

@wraps(wrapped)
def wrapper(*args, **kwargs) -> T:
type_hints = typing.get_type_hints(wrapped)
signature = inspect.signature(wrapped)

for arg_label, arg in zip(signature.parameters.keys(), args):
if arg_label in type_hints:
_validate_or_skip(type_hints[arg_label], arg)

result = wrapped(*args, **kwargs)

if "return" in type_hints:
_validate_or_skip(type_hints["return"], result)

return result

return wrapper
52 changes: 52 additions & 0 deletions tests/test_decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""Tests for patito.decorators"""
import pytest
import patito as pt
import polars as pl


def test_validate_hints_arg_validation_pass():
class MyModel(pt.Model):
a: int

@pt.validate_hints
def func(arg: MyModel) -> None:
pass

polars_dataframe = pl.DataFrame({"a": [1]})
func(polars_dataframe)


def test_validate_hints_arg_validation_fail():
class MyModel(pt.Model):
a: int

@pt.validate_hints
def func(arg: MyModel) -> None:
pass

polars_dataframe = pl.DataFrame({"a": ["b"]})
with pytest.raises(pt.ValidationError):
func(polars_dataframe)


def test_validate_hints_return_validation_pass():
class MyModel(pt.Model):
a: int

@pt.validate_hints
def func() -> MyModel:
return pl.DataFrame({"a": [1]})

func()


def test_validate_hints_return_validation_fail():
class MyModel(pt.Model):
a: int

@pt.validate_hints
def func() -> MyModel:
return pl.DataFrame({"a": ["b"]})

with pytest.raises(pt.ValidationError):
func()