Skip to content

Commit

Permalink
Add support for binary union types - Python 3.10 (#1977)
Browse files Browse the repository at this point in the history
Co-authored-by: Pierre Sassoulas <[email protected]>
  • Loading branch information
cdce8p and Pierre-Sassoulas authored Jan 30, 2023
1 parent 0545192 commit 156db06
Show file tree
Hide file tree
Showing 5 changed files with 203 additions and 4 deletions.
3 changes: 3 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ What's New in astroid 2.14.0?
=============================
Release date: TBA

* Add support for inferring binary union types added in Python 3.10.

Refs PyCQA/pylint#8119


What's New in astroid 2.13.4?
Expand Down
43 changes: 41 additions & 2 deletions astroid/bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,12 @@ def __init__(
if proxied is None:
# This is a hack to allow calling this __init__ during bootstrapping of
# builtin classes and their docstrings.
# For Const and Generator nodes the _proxied attribute is set during bootstrapping
# For Const, Generator, and UnionType nodes the _proxied attribute
# is set during bootstrapping
# as we first need to build the ClassDef that they can proxy.
# Thus, if proxied is None self should be a Const or Generator
# as that is the only way _proxied will be correctly set as a ClassDef.
assert isinstance(self, (nodes.Const, Generator))
assert isinstance(self, (nodes.Const, Generator, UnionType))
else:
self._proxied = proxied

Expand Down Expand Up @@ -669,3 +670,41 @@ def __repr__(self) -> str:

def __str__(self) -> str:
return f"AsyncGenerator({self._proxied.name})"


class UnionType(BaseInstance):
"""Special node representing new style typing unions.
Proxied class is set once for all in raw_building.
"""

_proxied: nodes.ClassDef

def __init__(
self,
left: UnionType | nodes.ClassDef | nodes.Const,
right: UnionType | nodes.ClassDef | nodes.Const,
parent: nodes.NodeNG | None = None,
) -> None:
super().__init__()
self.parent = parent
self.left = left
self.right = right

def callable(self) -> Literal[False]:
return False

def bool_value(self, context: InferenceContext | None = None) -> Literal[True]:
return True

def pytype(self) -> Literal["types.UnionType"]:
return "types.UnionType"

def display_type(self) -> str:
return "UnionType"

def __repr__(self) -> str:
return f"<UnionType({self._proxied.name}) l.{self.lineno} at 0x{id(self)}>"

def __str__(self) -> str:
return f"UnionType({self._proxied.name})"
25 changes: 25 additions & 0 deletions astroid/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union

from astroid import bases, constraint, decorators, helpers, nodes, protocols, util
from astroid.const import PY310_PLUS
from astroid.context import (
CallContext,
InferenceContext,
Expand Down Expand Up @@ -758,6 +759,14 @@ def _bin_op(
)


def _bin_op_or_union_type(
left: bases.UnionType | nodes.ClassDef | nodes.Const,
right: bases.UnionType | nodes.ClassDef | nodes.Const,
) -> Generator[InferenceResult, None, None]:
"""Create a new UnionType instance for binary or, e.g. int | str."""
yield bases.UnionType(left, right)


def _get_binop_contexts(context, left, right):
"""Get contexts for binary operations.
Expand Down Expand Up @@ -817,6 +826,22 @@ def _get_binop_flow(
_bin_op(left, binary_opnode, op, right, context),
_bin_op(right, binary_opnode, op, left, reverse_context, reverse=True),
]

if (
PY310_PLUS
and op == "|"
and (
isinstance(left, (bases.UnionType, nodes.ClassDef))
or isinstance(left, nodes.Const)
and left.value is None
)
and (
isinstance(right, (bases.UnionType, nodes.ClassDef))
or isinstance(right, nodes.Const)
and right.value is None
)
):
methods.extend([functools.partial(_bin_op_or_union_type, left, right)])
return methods


Expand Down
18 changes: 18 additions & 0 deletions astroid/raw_building.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,24 @@ def _astroid_bootstrapping() -> None:
)
bases.AsyncGenerator._proxied = _AsyncGeneratorType
builder.object_build(bases.AsyncGenerator._proxied, types.AsyncGeneratorType)

if hasattr(types, "UnionType"):
_UnionTypeType = nodes.ClassDef(types.UnionType.__name__)
_UnionTypeType.parent = astroid_builtin
union_type_doc_node = (
nodes.Const(value=types.UnionType.__doc__)
if types.UnionType.__doc__
else None
)
_UnionTypeType.postinit(
bases=[],
body=[],
decorators=None,
doc_node=union_type_doc_node,
)
bases.UnionType._proxied = _UnionTypeType
builder.object_build(bases.UnionType._proxied, types.UnionType)

builtin_types = (
types.GetSetDescriptorType,
types.GeneratorType,
Expand Down
118 changes: 116 additions & 2 deletions tests/unittest_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
from astroid import decorators as decoratorsmod
from astroid import helpers, nodes, objects, test_utils, util
from astroid.arguments import CallSite
from astroid.bases import BoundMethod, Instance, UnboundMethod
from astroid.bases import BoundMethod, Instance, UnboundMethod, UnionType
from astroid.builder import AstroidBuilder, _extract_single_node, extract_node, parse
from astroid.const import IS_PYPY, PY38_PLUS, PY39_PLUS
from astroid.const import IS_PYPY, PY38_PLUS, PY39_PLUS, PY310_PLUS
from astroid.context import InferenceContext
from astroid.exceptions import (
AstroidTypeError,
Expand Down Expand Up @@ -1209,6 +1209,120 @@ def randint(maximum):
],
)

def test_binary_op_or_union_type(self) -> None:
"""Binary or union is only defined for Python 3.10+."""
code = """
class A: ...
int | 2 #@
int | "Hello" #@
int | ... #@
int | A() #@
int | None | 2 #@
"""
ast_nodes = extract_node(code)
for n in ast_nodes:
assert n.inferred() == [util.Uninferable]

code = """
from typing import List
class A: ...
class B: ...
int | None #@
int | str #@
int | str | None #@
A | B #@
A | None #@
List[int] | int #@
tuple | int #@
"""
ast_nodes = extract_node(code)
if not PY310_PLUS:
for n in ast_nodes:
assert n.inferred() == [util.Uninferable]
else:
i0 = ast_nodes[0].inferred()[0]
assert isinstance(i0, UnionType)
assert isinstance(i0.left, nodes.ClassDef)
assert i0.left.name == "int"
assert isinstance(i0.right, nodes.Const)
assert i0.right.value is None

# Assert basic UnionType properties and methods
assert i0.callable() is False
assert i0.bool_value() is True
assert i0.pytype() == "types.UnionType"
assert i0.display_type() == "UnionType"
assert str(i0) == "UnionType(UnionType)"
assert repr(i0) == f"<UnionType(UnionType) l.None at 0x{id(i0)}>"

i1 = ast_nodes[1].inferred()[0]
assert isinstance(i1, UnionType)

i2 = ast_nodes[2].inferred()[0]
assert isinstance(i2, UnionType)
assert isinstance(i2.left, UnionType)
assert isinstance(i2.left.left, nodes.ClassDef)
assert i2.left.left.name == "int"
assert isinstance(i2.left.right, nodes.ClassDef)
assert i2.left.right.name == "str"
assert isinstance(i2.right, nodes.Const)
assert i2.right.value is None

i3 = ast_nodes[3].inferred()[0]
assert isinstance(i3, UnionType)
assert isinstance(i3.left, nodes.ClassDef)
assert i3.left.name == "A"
assert isinstance(i3.right, nodes.ClassDef)
assert i3.right.name == "B"

i4 = ast_nodes[4].inferred()[0]
assert isinstance(i4, UnionType)

i5 = ast_nodes[5].inferred()[0]
assert isinstance(i5, UnionType)
assert isinstance(i5.left, nodes.ClassDef)
assert i5.left.name == "List"

i6 = ast_nodes[6].inferred()[0]
assert isinstance(i6, UnionType)
assert isinstance(i6.left, nodes.ClassDef)
assert i6.left.name == "tuple"

code = """
from typing import List
Alias1 = List[int]
Alias2 = str | int
Alias1 | int #@
Alias2 | int #@
Alias1 | Alias2 #@
"""
ast_nodes = extract_node(code)
if not PY310_PLUS:
for n in ast_nodes:
assert n.inferred() == [util.Uninferable]
else:
i0 = ast_nodes[0].inferred()[0]
assert isinstance(i0, UnionType)
assert isinstance(i0.left, nodes.ClassDef)
assert i0.left.name == "List"

i1 = ast_nodes[1].inferred()[0]
assert isinstance(i1, UnionType)
assert isinstance(i1.left, UnionType)
assert isinstance(i1.left.left, nodes.ClassDef)
assert i1.left.left.name == "str"

i2 = ast_nodes[2].inferred()[0]
assert isinstance(i2, UnionType)
assert isinstance(i2.left, nodes.ClassDef)
assert i2.left.name == "List"
assert isinstance(i2.right, UnionType)

def test_nonregr_lambda_arg(self) -> None:
code = """
def f(g = lambda: None):
Expand Down

0 comments on commit 156db06

Please sign in to comment.