diff --git a/ChangeLog b/ChangeLog index 317d69e3e5..4b95f224fe 100644 --- a/ChangeLog +++ b/ChangeLog @@ -12,6 +12,8 @@ What's New in astroid 2.13.5? ============================= Release date: TBA +* Revert ``CallContext`` change as it caused a ``RecursionError`` regression. + What's New in astroid 2.13.4? diff --git a/astroid/brain/brain_typing.py b/astroid/brain/brain_typing.py index 6a13407222..b11bfa1965 100644 --- a/astroid/brain/brain_typing.py +++ b/astroid/brain/brain_typing.py @@ -28,6 +28,7 @@ Const, JoinedStr, Name, + NodeNG, Subscript, Tuple, ) @@ -379,6 +380,36 @@ def infer_special_alias( return iter([class_def]) +def _looks_like_typing_cast(node: Call) -> bool: + return isinstance(node, Call) and ( + isinstance(node.func, Name) + and node.func.name == "cast" + or isinstance(node.func, Attribute) + and node.func.attrname == "cast" + ) + + +def infer_typing_cast( + node: Call, ctx: context.InferenceContext | None = None +) -> Iterator[NodeNG]: + """Infer call to cast() returning same type as casted-from var.""" + if not isinstance(node.func, (Name, Attribute)): + raise UseInferenceDefault + + try: + func = next(node.func.infer(context=ctx)) + except (InferenceError, StopIteration) as exc: + raise UseInferenceDefault from exc + if ( + not isinstance(func, FunctionDef) + or func.qname() != "typing.cast" + or len(node.args) != 2 + ): + raise UseInferenceDefault + + return node.args[1].infer(context=ctx) + + AstroidManager().register_transform( Call, inference_tip(infer_typing_typevar_or_newtype), @@ -387,6 +418,9 @@ def infer_special_alias( AstroidManager().register_transform( Subscript, inference_tip(infer_typing_attr), _looks_like_typing_subscript ) +AstroidManager().register_transform( + Call, inference_tip(infer_typing_cast), _looks_like_typing_cast +) if PY39_PLUS: AstroidManager().register_transform( diff --git a/astroid/context.py b/astroid/context.py index 81b02f11c4..b469964805 100644 --- a/astroid/context.py +++ b/astroid/context.py @@ -161,14 +161,13 @@ def __str__(self) -> str: class CallContext: """Holds information for a call site.""" - __slots__ = ("args", "keywords", "callee", "parent_call_context") + __slots__ = ("args", "keywords", "callee") def __init__( self, args: list[NodeNG], keywords: list[Keyword] | None = None, callee: NodeNG | None = None, - parent_call_context: CallContext | None = None, ): self.args = args # Call positional arguments if keywords: @@ -177,9 +176,6 @@ def __init__( arg_value_pairs = [] self.keywords = arg_value_pairs # Call keyword arguments self.callee = callee # Function being called - self.parent_call_context = ( - parent_call_context # Parent CallContext for nested calls - ) def copy_context(context: InferenceContext | None) -> InferenceContext: diff --git a/astroid/inference.py b/astroid/inference.py index 59bc4eca56..e8fec289fa 100644 --- a/astroid/inference.py +++ b/astroid/inference.py @@ -273,10 +273,7 @@ def infer_call( try: if hasattr(callee, "infer_call_result"): callcontext.callcontext = CallContext( - args=self.args, - keywords=self.keywords, - callee=callee, - parent_call_context=callcontext.callcontext, + args=self.args, keywords=self.keywords, callee=callee ) yield from callee.infer_call_result(caller=self, context=callcontext) except InferenceError: diff --git a/astroid/protocols.py b/astroid/protocols.py index 48f0cd0f09..72549b7952 100644 --- a/astroid/protocols.py +++ b/astroid/protocols.py @@ -470,7 +470,7 @@ def arguments_assigned_stmts( # reset call context/name callcontext = context.callcontext context = copy_context(context) - context.callcontext = callcontext.parent_call_context + context.callcontext = None args = arguments.CallSite(callcontext, context=context) return args.infer_argument(self.parent, node_name, context) return _arguments_infer_argname(self, node_name, context) diff --git a/tests/unittest_brain.py b/tests/unittest_brain.py index 30594c385f..dd929bd0de 100644 --- a/tests/unittest_brain.py +++ b/tests/unittest_brain.py @@ -2160,6 +2160,13 @@ class A: assert inferred.value == 42 def test_typing_cast_multiple_inference_calls(self) -> None: + """Inference of an outer function should not store the result for cast. + + https://github.com/PyCQA/pylint/issues/8074 + + Possible solution caused RecursionErrors with Python 3.8 and CPython + PyPy. + https://github.com/PyCQA/astroid/pull/1982 + """ ast_nodes = builder.extract_node( """ from typing import TypeVar, cast @@ -2177,7 +2184,7 @@ def ident(var: T) -> T: i1 = next(ast_nodes[1].infer()) assert isinstance(i1, nodes.Const) - assert i1.value == "Hello" + assert i1.value == 2 # should be "Hello"! @pytest.mark.skipif( diff --git a/tests/unittest_inference.py b/tests/unittest_inference.py index e66103d978..86443d895f 100644 --- a/tests/unittest_inference.py +++ b/tests/unittest_inference.py @@ -23,7 +23,7 @@ from astroid.arguments import CallSite from astroid.bases import BoundMethod, Instance, UnboundMethod from astroid.builder import AstroidBuilder, _extract_single_node, extract_node, parse -from astroid.const import IS_PYPY, PY38_PLUS, PY39_PLUS +from astroid.const import PY38_PLUS, PY39_PLUS from astroid.context import InferenceContext from astroid.exceptions import ( AstroidTypeError, @@ -6820,9 +6820,6 @@ def test_imported_module_var_inferable3() -> None: assert i_w_val.as_string() == "['w', 'v']" -@pytest.mark.skipif( - IS_PYPY, reason="Test run with coverage on PyPy sometimes raises a RecursionError" -) def test_recursion_on_inference_tip() -> None: """Regression test for recursion in inference tip. diff --git a/tests/unittest_inference_calls.py b/tests/unittest_inference_calls.py index 84a611d3a4..72afb9898c 100644 --- a/tests/unittest_inference_calls.py +++ b/tests/unittest_inference_calls.py @@ -146,6 +146,8 @@ def g(y): def test_inner_call_with_dynamic_argument() -> None: """Test function where return value is the result of a separate function call, with a dynamic value passed to the inner function. + + Currently, this is Uninferable. """ node = builder.extract_node( """ @@ -161,8 +163,7 @@ def g(y): assert isinstance(node, nodes.NodeNG) inferred = node.inferred() assert len(inferred) == 1 - assert isinstance(inferred[0], nodes.Const) - assert inferred[0].value == 3 + assert inferred[0] is Uninferable def test_method_const_instance_attr() -> None: