diff --git a/examples/calc.py b/examples/calc.py index 9e9aa78f..3474f34c 100644 --- a/examples/calc.py +++ b/examples/calc.py @@ -46,6 +46,7 @@ class CalculateTree(Transformer): number = float def __init__(self): + super().__init__() self.vars = {} def assign_var(self, name, value): diff --git a/lark/__init__.py b/lark/__init__.py index cd9ea9f6..0552bf85 100644 --- a/lark/__init__.py +++ b/lark/__init__.py @@ -1,6 +1,6 @@ from .utils import logger from .tree import Tree -from .visitors import Transformer, Visitor, v_args, Discard, Transformer_NonRecursive +from .visitors import Transformer, Visitor, v_args, Discard, Transformer_NonRecursive, call_for from .exceptions import (ParseError, LexError, GrammarError, UnexpectedToken, UnexpectedInput, UnexpectedCharacters, UnexpectedEOF, LarkError) from .lexer import Token diff --git a/lark/load_grammar.py b/lark/load_grammar.py index 7652b442..7a74e7e7 100644 --- a/lark/load_grammar.py +++ b/lark/load_grammar.py @@ -195,6 +195,7 @@ class FindRuleSize(Transformer): def __init__(self, keep_all_tokens): + super().__init__() self.keep_all_tokens = keep_all_tokens def _will_not_get_removed(self, sym): @@ -225,6 +226,7 @@ def expansions(self, args): @inline_args class EBNF_to_BNF(Transformer_InPlace): def __init__(self): + super().__init__() self.new_rules = [] self.rules_cache = {} self.prefix = 'anon' @@ -440,6 +442,7 @@ class PrepareAnonTerminals(Transformer_InPlace): """Create a unique list of anonymous terminals. Attempt to give meaningful names to them when we add them""" def __init__(self, terminals): + super().__init__() self.terminals = terminals self.term_set = {td.name for td in self.terminals} self.term_reverse = {td.pattern: td for td in terminals} @@ -495,6 +498,7 @@ class _ReplaceSymbols(Transformer_InPlace): """Helper for ApplyTemplates""" def __init__(self): + super().__init__() self.names = {} def value(self, c): @@ -513,6 +517,7 @@ class ApplyTemplates(Transformer_InPlace): """Apply the templates, creating new rules that represent the used templates""" def __init__(self, rule_defs): + super().__init__() self.rule_defs = rule_defs self.replacer = _ReplaceSymbols() self.created_templates = set() diff --git a/lark/reconstruct.py b/lark/reconstruct.py index 02b49476..5acff709 100644 --- a/lark/reconstruct.py +++ b/lark/reconstruct.py @@ -27,6 +27,7 @@ class WriteTokensTransformer(Transformer_InPlace): term_subs: Dict[str, Callable[[Symbol], str]] def __init__(self, tokens: Dict[str, TerminalDef], term_subs: Dict[str, Callable[[Symbol], str]]) -> None: + super().__init__() self.tokens = tokens self.term_subs = term_subs diff --git a/lark/tools/nearley.py b/lark/tools/nearley.py index 8b8ef89e..394a64d0 100644 --- a/lark/tools/nearley.py +++ b/lark/tools/nearley.py @@ -53,6 +53,7 @@ def _get_rulename(name): @v_args(inline=True) class NearleyToLark(Transformer): def __init__(self): + super().__init__() self._count = 0 self.extra_rules = {} self.extra_rules_rev = {} diff --git a/lark/tree_templates.py b/lark/tree_templates.py index 8db8b35c..3ea1ca5a 100644 --- a/lark/tree_templates.py +++ b/lark/tree_templates.py @@ -76,6 +76,7 @@ def _match_tree_template(self, template, tree): class _ReplaceVars(Transformer): def __init__(self, conf, vars): + super().__init__() self._conf = conf self._vars = vars diff --git a/lark/visitors.py b/lark/visitors.py index fd49ca02..52a2d693 100644 --- a/lark/visitors.py +++ b/lark/visitors.py @@ -1,4 +1,4 @@ -from typing import TypeVar, Tuple, List, Callable, Generic, Type, Union, Optional, Any +from typing import TypeVar, Tuple, List, Callable, Generic, Type, Union, Optional, Any, Dict from abc import ABC from .utils import combine_alternatives @@ -9,6 +9,7 @@ ###{standalone from functools import wraps, update_wrapper from inspect import getmembers, getmro +import warnings _T = TypeVar('_T') _R = TypeVar('_R') @@ -35,6 +36,49 @@ def __repr__(self): Discard = _DiscardType() +# User function lookup and overrides + + +def call_for(rule_name: str) -> Callable: + def _call_for(func: Callable) -> Callable: + func._rule_name = rule_name + return func + + return _call_for + + +class _UserFuncLookup: + _user_func_overrides: Dict[str, Callable] + + def __init__(self): + self._user_func_overrides = {} + + def _look_up_user_func(self, rule_name: str) -> Optional[Callable]: + user_func = getattr(self, rule_name, None) + if user_func is not None: + return user_func + + # backwards compatibility for subclass not calling __init__() + if not hasattr(self, "_user_func_overrides"): + warnings.warn("Subclasses of Transformer and Visitor should call super().__init__().", + DeprecationWarning) + self._user_func_overrides = {} + + # check cache + user_func = self._user_func_overrides.get(rule_name) + if user_func is not None: + return user_func + + for attr_name in dir(self): + if attr_name.startswith("_"): + continue + attr = getattr(self, attr_name) + if hasattr(attr, "_rule_name") and attr._rule_name == rule_name: + self._user_func_overrides[attr._rule_name] = attr + return attr + + return None + # Transformers class _Decoratable: @@ -64,7 +108,7 @@ def __class_getitem__(cls, _): return cls -class Transformer(_Decoratable, ABC, Generic[_T]): +class Transformer(_UserFuncLookup, _Decoratable, ABC, Generic[_T]): """Transformers visit each node of the tree, and run the appropriate method on it according to the node's data. Methods are provided by the user via inheritance, and called according to ``tree.data``. @@ -95,39 +139,38 @@ class Transformer(_Decoratable, ABC, Generic[_T]): __visit_tokens__ = True # For backwards compatibility def __init__(self, visit_tokens: bool=True) -> None: + super().__init__() self.__visit_tokens__ = visit_tokens def _call_userfunc(self, tree, new_children=None): # Assumes tree is already transformed children = new_children if new_children is not None else tree.children - try: - f = getattr(self, tree.data) - except AttributeError: + f = super()._look_up_user_func(tree.data) + if f is None: return self.__default__(tree.data, children, tree.meta) - else: - try: - wrapper = getattr(f, 'visit_wrapper', None) - if wrapper is not None: - return f.visit_wrapper(f, tree.data, children, tree.meta) - else: - return f(children) - except GrammarError: - raise - except Exception as e: - raise VisitError(tree.data, tree, e) - def _call_userfunc_token(self, token): try: - f = getattr(self, token.type) - except AttributeError: + wrapper = getattr(f, 'visit_wrapper', None) + if wrapper is not None: + return f.visit_wrapper(f, tree.data, children, tree.meta) + else: + return f(children) + except GrammarError: + raise + except Exception as e: + raise VisitError(tree.data, tree, e) + + def _call_userfunc_token(self, token): + f = super()._look_up_user_func(token.type) + if f is None: return self.__default_token__(token) - else: - try: - return f(token) - except GrammarError: - raise - except Exception as e: - raise VisitError(token.type, token, e) + + try: + return f(token) + except GrammarError: + raise + except Exception as e: + raise VisitError(token.type, token, e) def _transform_children(self, children): for c in children: @@ -204,6 +247,19 @@ def foo(self, children): assert composed_transformer.transform(t) == 'foobar' """ + prefix_format = "{}__{}" + + def _make_merged_method(prefix_with: str, to_wrap: Callable) -> Callable: + # Python methods don't allow attributes to be set, while that is allowed for functions. As + # a result, a wrapping function is needed to update the rule name. + # A factory function is needed to capture a reference to the method that is being wrapped. + @wraps(to_wrap) + def _merged_method(*args, **kwargs): + return to_wrap(*args, **kwargs) + + _merged_method._rule_name = prefix_format.format(prefix_with, to_wrap._rule_name) + return _merged_method + if base_transformer is None: base_transformer = Transformer() for prefix, transformer in transformers_to_merge.items(): @@ -213,10 +269,12 @@ def foo(self, children): continue if method_name.startswith("_") or method_name == "transform": continue - prefixed_method = prefix + "__" + method_name + prefixed_method = prefix_format.format(prefix, method_name) if hasattr(base_transformer, prefixed_method): raise AttributeError("Cannot merge: method '%s' appears more than once" % prefixed_method) + if hasattr(method, "_rule_name"): + method = _make_merged_method(prefix, method) setattr(base_transformer, prefixed_method, method) return base_transformer @@ -226,12 +284,10 @@ class InlineTransformer(Transformer): # XXX Deprecated def _call_userfunc(self, tree, new_children=None): # Assumes tree is already transformed children = new_children if new_children is not None else tree.children - try: - f = getattr(self, tree.data) - except AttributeError: + f = super()._look_up_user_func(tree.data) + if f is None: return self.__default__(tree.data, children, tree.meta) - else: - return f(*children) + return f(*children) class TransformerChain(Generic[_T]): @@ -317,9 +373,12 @@ def _transform_tree(self, tree): # Visitors -class VisitorBase: +class VisitorBase(_UserFuncLookup): def _call_userfunc(self, tree): - return getattr(self, tree.data, self.__default__)(tree) + f = super()._look_up_user_func(tree.data) + if f is None: + f = self.__default__ + return f(tree) def __default__(self, tree): """Default function that is called if there is no attribute matching ``tree.data`` @@ -379,7 +438,7 @@ def visit_topdown(self,tree: Tree) -> Tree: return tree -class Interpreter(_Decoratable, ABC, Generic[_T]): +class Interpreter(_UserFuncLookup, _Decoratable, ABC, Generic[_T]): """Interpreter walks the tree starting at the root. Visits the tree, starting with the root and finally the leaves (top-down) @@ -392,7 +451,10 @@ class Interpreter(_Decoratable, ABC, Generic[_T]): """ def visit(self, tree: Tree) -> _T: - f = getattr(self, tree.data) + f = super()._look_up_user_func(tree.data) + if f is None: + return self.__default__(tree) + wrapper = getattr(f, 'visit_wrapper', None) if wrapper is not None: return f.visit_wrapper(f, tree.data, tree.children, tree.meta) @@ -403,9 +465,6 @@ def visit_children(self, tree: Tree) -> List[_T]: return [self.visit(child) if isinstance(child, Tree) else child for child in tree.children] - def __getattr__(self, name): - return self.__default__ - def __default__(self, tree): return self.visit_children(tree) diff --git a/tests/test_trees.py b/tests/test_trees.py index dd95f6b8..47f39b15 100644 --- a/tests/test_trees.py +++ b/tests/test_trees.py @@ -11,7 +11,7 @@ from lark.tree import Tree from lark.lexer import Token from lark.visitors import Visitor, Visitor_Recursive, Transformer, Interpreter, visit_children_decor, v_args, Discard, Transformer_InPlace, \ - Transformer_InPlaceRecursive, Transformer_NonRecursive, merge_transformers + Transformer_InPlaceRecursive, Transformer_NonRecursive, merge_transformers, call_for class TestTrees(TestCase): @@ -51,12 +51,14 @@ def test_iter_subtrees_topdown(self): def test_visitor(self): class Visitor1(Visitor): def __init__(self): + super().__init__() self.nodes=[] def __default__(self,tree): self.nodes.append(tree) class Visitor1_Recursive(Visitor_Recursive): def __init__(self): + super().__init__() self.nodes=[] def __default__(self,tree): @@ -447,5 +449,143 @@ class T4(Transformer): with self.assertRaises(AttributeError): merge_transformers(T1(), module=T3()) + def test_call_for_transformer(self): + t = Tree('add', [Tree('sub', [Tree('i', ['3']), Tree('f', ['1.1'])]), Tree('i', ['1'])]) + + class T(Transformer): + @call_for("i") + def int_(self, values) -> int: + return int(values[0]) + + @call_for("f") + def float_(self, values) -> float: + return float(values[0]) + + def sub(self, values): + return values[0] - values[1] + + def add(self, values): + return sum(values) + + res = T().transform(t) + self.assertEqual(res, 2.9) + + def test_call_for_transformer_subclass(self): + t = Tree('add', [Tree('sub', [Tree('i', ['3']), Tree('f', ['1.1'])]), Tree('i', ['1'])]) + + class T(Transformer): + @call_for("i") + def int_(self, values) -> int: + return int(values[0]) + + def sub(self, values): + return values[0] - values[1] + + def add(self, values): + return sum(values) + + class TT(T): + @call_for("f") + def float_(self, values) -> float: + return float(values[0]) + + res = TT().transform(t) + self.assertEqual(res, 2.9) + + def test_call_for_transformer_merge(self): + t = Tree('add', [Tree('sub', [Tree('i', ['3']), Tree('t2__f', ['1.1'])]), Tree('i', ['1'])]) + + class T1(Transformer): + @call_for("i") + def int_(self, values) -> int: + return int(values[0]) + + def sub(self, values): + return values[0] - values[1] + + def add(self, values): + return sum(values) + + class T2(Transformer): + @call_for("f") + def float_(self, values) -> float: + return float(values[0]) + + merged_transformer = merge_transformers(T1(), t2=T2()) + + res = merged_transformer.transform(t) + self.assertEqual(res, 2.9) + + def test_call_for_transformer_merge_no_base(self): + t = Tree('t1__add', [Tree('t1__sub', [Tree('t1__i', ['3']), Tree('t2__f', ['1.1'])]), Tree('t1__i', ['1'])]) + + class T1(Transformer): + @call_for("i") + def int_(self, values) -> int: + return int(values[0]) + + def sub(self, values): + return values[0] - values[1] + + def add(self, values): + return sum(values) + + class T2(Transformer): + @call_for("f") + def float_(self, values) -> float: + return float(values[0]) + + merged_transformer = merge_transformers(t1=T1(), t2=T2()) + + res = merged_transformer.transform(t) + self.assertEqual(res, 2.9) + + def test_call_for_visitor(self): + t = Tree('add', [Tree('sub', [Tree('i', ['3']), Tree('f', ['1.1'])]), Tree('i', ['1'])]) + + class V(Visitor): + def __init__(self) -> None: + super().__init__() + self.int_called = False + self.float_called = False + + @call_for("i") + def int_(self, _) -> None: + self.int_called = True + + @call_for("f") + def float_(self, _) -> None: + self.float_called = True + + v = V() + v.visit(t) + + self.assertTrue(v.int_called) + self.assertTrue(v.float_called) + + def test_call_for_visitor_subclass(self): + t = Tree('add', [Tree('sub', [Tree('i', ['3']), Tree('f', ['1.1'])]), Tree('i', ['1'])]) + + class V(Visitor): + def __init__(self) -> None: + super().__init__() + self.int_called = False + self.float_called = False + + @call_for("i") + def int_(self, _) -> None: + self.int_called = True + + class VV(V): + @call_for("f") + def float_(self, _) -> None: + self.float_called = True + + v = VV() + v.visit(t) + + self.assertTrue(v.int_called) + self.assertTrue(v.float_called) + if __name__ == '__main__': unittest.main()