Skip to content

Commit

Permalink
Use a TypeVar to pass the type through packaging.specifiers.filter() (#…
Browse files Browse the repository at this point in the history
…430)

This passes through the type of a list in filter.

Example:

    x: typing.List[Version]
    reveal_type(SpecifierSet("*").filter(x))

Before:

    Revealed type is 'typing.Iterable[Union[packaging.version.Version, packaging.version.LegacyVersion, builtins.str]]'

After:

    Revealed type is 'typing.Iterable[packaging.version.Version*]'

Signed-off-by: Henry Schreiner <[email protected]>
  • Loading branch information
henryiii authored May 28, 2021
1 parent 8c78c35 commit 7350746
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions packaging/specifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Pattern,
Set,
Tuple,
TypeVar,
Union,
)

Expand All @@ -25,6 +26,7 @@

ParsedVersion = Union[Version, LegacyVersion]
UnparsedVersion = Union[Version, LegacyVersion, str]
VersionTypeVar = TypeVar("VersionTypeVar", bound=UnparsedVersion)
CallableOperator = Callable[[ParsedVersion, str], bool]


Expand Down Expand Up @@ -84,8 +86,8 @@ def contains(self, item: str, prereleases: Optional[bool] = None) -> bool:

@abc.abstractmethod
def filter(
self, iterable: Iterable[UnparsedVersion], prereleases: Optional[bool] = None
) -> Iterable[UnparsedVersion]:
self, iterable: Iterable[VersionTypeVar], prereleases: Optional[bool] = None
) -> Iterable[VersionTypeVar]:
"""
Takes an iterable of items and filters them so that only items which
are contained within this specifier are allowed in it.
Expand Down Expand Up @@ -205,8 +207,8 @@ def contains(
return operator_callable(normalized_item, self.version)

def filter(
self, iterable: Iterable[UnparsedVersion], prereleases: Optional[bool] = None
) -> Iterable[UnparsedVersion]:
self, iterable: Iterable[VersionTypeVar], prereleases: Optional[bool] = None
) -> Iterable[VersionTypeVar]:

yielded = False
found_prereleases = []
Expand Down Expand Up @@ -773,8 +775,8 @@ def contains(
return all(s.contains(item, prereleases=prereleases) for s in self._specs)

def filter(
self, iterable: Iterable[UnparsedVersion], prereleases: Optional[bool] = None
) -> Iterable[UnparsedVersion]:
self, iterable: Iterable[VersionTypeVar], prereleases: Optional[bool] = None
) -> Iterable[VersionTypeVar]:

# Determine if we're forcing a prerelease or not, if we're not forcing
# one for this particular filter call, then we'll use whatever the
Expand All @@ -793,8 +795,11 @@ def filter(
# which will filter out any pre-releases, unless there are no final
# releases, and which will filter out LegacyVersion in general.
else:
filtered: List[UnparsedVersion] = []
found_prereleases: List[UnparsedVersion] = []
filtered: List[VersionTypeVar] = []
found_prereleases: List[VersionTypeVar] = []

item: UnparsedVersion
parsed_version: Union[Version, LegacyVersion]

for item in iterable:
# Ensure that we some kind of Version class for this item.
Expand Down

0 comments on commit 7350746

Please sign in to comment.