From 1b4fd0da1c6c5f1ed3ffeb43e00fc38c0bc520d0 Mon Sep 17 00:00:00 2001 From: Patrick Lannigan Date: Fri, 7 Jan 2022 07:10:20 -0500 Subject: [PATCH 1/3] Support calling methods that doesn't match a rule name --- lark/__init__.py | 2 +- lark/visitors.py | 139 ++++++++++++++++++++++++++++++++------------ tests/test_trees.py | 90 +++++++++++++++++++++++++++- 3 files changed, 191 insertions(+), 40 deletions(-) diff --git a/lark/__init__.py b/lark/__init__.py index cd9ea9f6..0552bf85 100644 --- a/lark/__init__.py +++ b/lark/__init__.py @@ -1,6 +1,6 @@ from .utils import logger from .tree import Tree -from .visitors import Transformer, Visitor, v_args, Discard, Transformer_NonRecursive +from .visitors import Transformer, Visitor, v_args, Discard, Transformer_NonRecursive, call_for from .exceptions import (ParseError, LexError, GrammarError, UnexpectedToken, UnexpectedInput, UnexpectedCharacters, UnexpectedEOF, LarkError) from .lexer import Token diff --git a/lark/visitors.py b/lark/visitors.py index 262562f3..fbc198d3 100644 --- a/lark/visitors.py +++ b/lark/visitors.py @@ -1,4 +1,4 @@ -from typing import TypeVar, Tuple, List, Callable, Generic, Type, Union, Optional, Any +from typing import TypeVar, Tuple, List, Callable, Generic, Type, Union, Optional, Any, Mapping, ClassVar from abc import ABC from functools import wraps, update_wrapper @@ -35,6 +35,69 @@ def __repr__(self): Discard = _DiscardType() +# User function lookup and aliases + + +class _UserFuncOverride: + def __init__(self, rule_name: str, user_func: Callable) -> None: + """ + Initialize an instance. + + Parameters: + rule_name: Name of the rule this should be replacing. + user_func: User function to call when rule is encountered. Will be a plain function, + NOT be bound to a class instance. + """ + self.rule_name = rule_name + self.user_func = user_func + update_wrapper(self, user_func) + + def __call__(self, *args, **kwargs): + return self.user_func(*args, **kwargs) + + def __get__(self, instance, owner=None): + # bind user function with instance when function is called directly + return self.user_func.__get__(instance, owner) + + +def call_for(rule_name: str) -> Callable: + def _call_for(func: Callable) -> Callable: + return _UserFuncOverride(rule_name, func) + + return _call_for + + +class _UserFuncLookup: + _user_func_overrides: ClassVar[Mapping[str, _UserFuncOverride]] + + def __init_subclass__(cls, **kwargs): + # cls is the subclass being initialized, so each subclass gets a unique override mapping + super().__init_subclass__(**kwargs) + all_overrides = {} + # Aggregate any overrides from parent classes (skipping the subclass & object). + # Reverse order so child overrides have precedence over parents. + for hierarchy_class in reversed(cls.__mro__[1:-1]): + aliases = getattr(hierarchy_class, "_user_func_overrides", None) + if aliases is not None: + all_overrides.update(aliases) + + all_overrides.update(cls._collect_user_func_overrides()) + cls._user_func_overrides = all_overrides + + @classmethod + def _collect_user_func_overrides(cls) -> Mapping[str, _UserFuncOverride]: + aliases = {} + for obj in cls.__dict__.values(): + if isinstance(obj, _UserFuncOverride): + aliases[obj.rule_name] = obj + return aliases + + def _look_up_user_func(self, rule_name: str) -> Optional[Callable]: + override = self._user_func_overrides.get(rule_name) + if override is None: + return getattr(self, rule_name, None) + return override.user_func.__get__(self, self.__class__) + # Transformers class _Decoratable: @@ -64,7 +127,7 @@ def __class_getitem__(cls, _): return cls -class Transformer(_Decoratable, ABC, Generic[_T]): +class Transformer(_UserFuncLookup, _Decoratable, ABC, Generic[_T]): """Transformers visit each node of the tree, and run the appropriate method on it according to the node's data. Methods are provided by the user via inheritance, and called according to ``tree.data``. @@ -100,34 +163,32 @@ def __init__(self, visit_tokens: bool=True) -> None: def _call_userfunc(self, tree, new_children=None): # Assumes tree is already transformed children = new_children if new_children is not None else tree.children - try: - f = getattr(self, tree.data) - except AttributeError: + f = super()._look_up_user_func(tree.data) + if f is None: return self.__default__(tree.data, children, tree.meta) - else: - try: - wrapper = getattr(f, 'visit_wrapper', None) - if wrapper is not None: - return f.visit_wrapper(f, tree.data, children, tree.meta) - else: - return f(children) - except GrammarError: - raise - except Exception as e: - raise VisitError(tree.data, tree, e) - def _call_userfunc_token(self, token): try: - f = getattr(self, token.type) - except AttributeError: + wrapper = getattr(f, 'visit_wrapper', None) + if wrapper is not None: + return f.visit_wrapper(f, tree.data, children, tree.meta) + else: + return f(children) + except GrammarError: + raise + except Exception as e: + raise VisitError(tree.data, tree, e) + + def _call_userfunc_token(self, token): + f = super()._look_up_user_func(token.type) + if f is None: return self.__default_token__(token) - else: - try: - return f(token) - except GrammarError: - raise - except Exception as e: - raise VisitError(token.type, token, e) + + try: + return f(token) + except GrammarError: + raise + except Exception as e: + raise VisitError(token.type, token, e) def _transform_children(self, children): for c in children: @@ -204,6 +265,7 @@ def foo(self, children): assert composed_transformer.transform(t) == 'foobar' """ + # TODO: investigate how this works with overrides if base_transformer is None: base_transformer = Transformer() for prefix, transformer in transformers_to_merge.items(): @@ -226,12 +288,10 @@ class InlineTransformer(Transformer): # XXX Deprecated def _call_userfunc(self, tree, new_children=None): # Assumes tree is already transformed children = new_children if new_children is not None else tree.children - try: - f = getattr(self, tree.data) - except AttributeError: + f = super()._look_up_user_func(tree.data) + if f is None: return self.__default__(tree.data, children, tree.meta) - else: - return f(*children) + return f(*children) class TransformerChain(Generic[_T]): @@ -317,9 +377,12 @@ def _transform_tree(self, tree): # Visitors -class VisitorBase: +class VisitorBase(_UserFuncLookup): def _call_userfunc(self, tree): - return getattr(self, tree.data, self.__default__)(tree) + f = super()._look_up_user_func(tree.data) + if f is None: + f = self.__default__ + return f(tree) def __default__(self, tree): """Default function that is called if there is no attribute matching ``tree.data`` @@ -379,7 +442,7 @@ def visit_topdown(self,tree: Tree) -> Tree: return tree -class Interpreter(_Decoratable, ABC, Generic[_T]): +class Interpreter(_UserFuncLookup, _Decoratable, ABC, Generic[_T]): """Interpreter walks the tree starting at the root. Visits the tree, starting with the root and finally the leaves (top-down) @@ -392,7 +455,10 @@ class Interpreter(_Decoratable, ABC, Generic[_T]): """ def visit(self, tree: Tree) -> _T: - f = getattr(self, tree.data) + f = super()._look_up_user_func(tree.data) + if f is None: + return self.__default__(tree) + wrapper = getattr(f, 'visit_wrapper', None) if wrapper is not None: return f.visit_wrapper(f, tree.data, tree.children, tree.meta) @@ -403,9 +469,6 @@ def visit_children(self, tree: Tree) -> List[_T]: return [self.visit(child) if isinstance(child, Tree) else child for child in tree.children] - def __getattr__(self, name): - return self.__default__ - def __default__(self, tree): return self.visit_children(tree) diff --git a/tests/test_trees.py b/tests/test_trees.py index dd95f6b8..12b0bedd 100644 --- a/tests/test_trees.py +++ b/tests/test_trees.py @@ -11,7 +11,7 @@ from lark.tree import Tree from lark.lexer import Token from lark.visitors import Visitor, Visitor_Recursive, Transformer, Interpreter, visit_children_decor, v_args, Discard, Transformer_InPlace, \ - Transformer_InPlaceRecursive, Transformer_NonRecursive, merge_transformers + Transformer_InPlaceRecursive, Transformer_NonRecursive, merge_transformers, call_for class TestTrees(TestCase): @@ -447,5 +447,93 @@ class T4(Transformer): with self.assertRaises(AttributeError): merge_transformers(T1(), module=T3()) + def test_call_for_transformer(self): + t = Tree('add', [Tree('sub', [Tree('i', ['3']), Tree('f', ['1.1'])]), Tree('i', ['1'])]) + + class T(Transformer): + @call_for("i") + def int_(self, values) -> int: + return int(values[0]) + + @call_for("f") + def float_(self, values) -> float: + return float(values[0]) + + def sub(self, values): + return values[0] - values[1] + + def add(self, values): + return sum(values) + + res = T().transform(t) + self.assertEqual(res, 2.9) + + def test_call_for_transformer_subclass(self): + t = Tree('add', [Tree('sub', [Tree('i', ['3']), Tree('f', ['1.1'])]), Tree('i', ['1'])]) + + class T(Transformer): + @call_for("i") + def int_(self, values) -> int: + return int(values[0]) + + def sub(self, values): + return values[0] - values[1] + + def add(self, values): + return sum(values) + + class TT(T): + @call_for("f") + def float_(self, values) -> float: + return float(values[0]) + + res = TT().transform(t) + self.assertEqual(res, 2.9) + + def test_call_for_visitor(self): + t = Tree('add', [Tree('sub', [Tree('i', ['3']), Tree('f', ['1.1'])]), Tree('i', ['1'])]) + + class V(Visitor): + def __init__(self) -> None: + self.int_called = False + self.float_called = False + + @call_for("i") + def int_(self, _) -> None: + self.int_called = True + + @call_for("f") + def float_(self, _) -> None: + self.float_called = True + + v = V() + v.visit(t) + + self.assertTrue(v.int_called) + self.assertTrue(v.float_called) + + def test_call_for_visitor_subclass(self): + t = Tree('add', [Tree('sub', [Tree('i', ['3']), Tree('f', ['1.1'])]), Tree('i', ['1'])]) + + class V(Visitor): + def __init__(self) -> None: + self.int_called = False + self.float_called = False + + @call_for("i") + def int_(self, _) -> None: + self.int_called = True + + class VV(V): + @call_for("f") + def float_(self, _) -> None: + self.float_called = True + + v = VV() + v.visit(t) + + self.assertTrue(v.int_called) + self.assertTrue(v.float_called) + if __name__ == '__main__': unittest.main() From fa00eb6faa1d7ee8b1308480d1fcd977dede3509 Mon Sep 17 00:00:00 2001 From: Patrick Lannigan Date: Tue, 11 Jan 2022 08:17:18 -0500 Subject: [PATCH 2/3] Pivot implmentaiton Use a marker on the function instead of a wrapper class Look up overrides only when needed and cache result --- lark/visitors.py | 102 +++++++++++++++++++++----------------------- tests/test_trees.py | 48 +++++++++++++++++++++ 2 files changed, 97 insertions(+), 53 deletions(-) diff --git a/lark/visitors.py b/lark/visitors.py index fbc198d3..5b32ba00 100644 --- a/lark/visitors.py +++ b/lark/visitors.py @@ -1,6 +1,7 @@ -from typing import TypeVar, Tuple, List, Callable, Generic, Type, Union, Optional, Any, Mapping, ClassVar +from typing import TypeVar, Tuple, List, Callable, Generic, Type, Union, Optional, Any, Dict from abc import ABC from functools import wraps, update_wrapper +import warnings from .utils import combine_alternatives from .tree import Tree @@ -35,68 +36,48 @@ def __repr__(self): Discard = _DiscardType() -# User function lookup and aliases - - -class _UserFuncOverride: - def __init__(self, rule_name: str, user_func: Callable) -> None: - """ - Initialize an instance. - - Parameters: - rule_name: Name of the rule this should be replacing. - user_func: User function to call when rule is encountered. Will be a plain function, - NOT be bound to a class instance. - """ - self.rule_name = rule_name - self.user_func = user_func - update_wrapper(self, user_func) - - def __call__(self, *args, **kwargs): - return self.user_func(*args, **kwargs) - - def __get__(self, instance, owner=None): - # bind user function with instance when function is called directly - return self.user_func.__get__(instance, owner) +# User function lookup and overrides def call_for(rule_name: str) -> Callable: def _call_for(func: Callable) -> Callable: - return _UserFuncOverride(rule_name, func) + func._rule_name = rule_name + return func return _call_for class _UserFuncLookup: - _user_func_overrides: ClassVar[Mapping[str, _UserFuncOverride]] - - def __init_subclass__(cls, **kwargs): - # cls is the subclass being initialized, so each subclass gets a unique override mapping - super().__init_subclass__(**kwargs) - all_overrides = {} - # Aggregate any overrides from parent classes (skipping the subclass & object). - # Reverse order so child overrides have precedence over parents. - for hierarchy_class in reversed(cls.__mro__[1:-1]): - aliases = getattr(hierarchy_class, "_user_func_overrides", None) - if aliases is not None: - all_overrides.update(aliases) - - all_overrides.update(cls._collect_user_func_overrides()) - cls._user_func_overrides = all_overrides + _user_func_overrides: Dict[str, Callable] - @classmethod - def _collect_user_func_overrides(cls) -> Mapping[str, _UserFuncOverride]: - aliases = {} - for obj in cls.__dict__.values(): - if isinstance(obj, _UserFuncOverride): - aliases[obj.rule_name] = obj - return aliases + def __init__(self): + self._user_func_overrides = {} def _look_up_user_func(self, rule_name: str) -> Optional[Callable]: - override = self._user_func_overrides.get(rule_name) - if override is None: - return getattr(self, rule_name, None) - return override.user_func.__get__(self, self.__class__) + user_func = getattr(self, rule_name, None) + if user_func is not None: + return user_func + + # backwards compatibility for subclass not calling __init__() + if not hasattr(self, "_user_func_overrides"): + warnings.warn("Subclasses of Transformer and Visitor should call super().__init__().", + DeprecationWarning) + self._user_func_overrides = {} + + # check cache + user_func = self._user_func_overrides.get(rule_name) + if user_func is not None: + return user_func + + for attr_name in dir(self): + if attr_name.startswith("_"): + continue + attr = getattr(self, attr_name) + if hasattr(attr, "_rule_name") and attr._rule_name == rule_name: + self._user_func_overrides[attr._rule_name] = attr + return attr + + return None # Transformers @@ -158,6 +139,7 @@ class Transformer(_UserFuncLookup, _Decoratable, ABC, Generic[_T]): __visit_tokens__ = True # For backwards compatibility def __init__(self, visit_tokens: bool=True) -> None: + super().__init__() self.__visit_tokens__ = visit_tokens def _call_userfunc(self, tree, new_children=None): @@ -265,7 +247,19 @@ def foo(self, children): assert composed_transformer.transform(t) == 'foobar' """ - # TODO: investigate how this works with overrides + prefix_format = "{}__{}" + + def _make_merged_method(prefix_with: str, to_wrap: Callable) -> Callable: + # Python methods don't allow attributes to be set, while that is allowed for functions. As + # a result, a wrapping function is needed to update the rule name. + # A factory function is needed to capture a reference to the method that is being wrapped. + @wraps(to_wrap) + def _merged_method(*args, **kwargs): + return to_wrap(*args, **kwargs) + + _merged_method._rule_name = prefix_format.format(prefix_with, to_wrap._rule_name) + return _merged_method + if base_transformer is None: base_transformer = Transformer() for prefix, transformer in transformers_to_merge.items(): @@ -275,10 +269,12 @@ def foo(self, children): continue if method_name.startswith("_") or method_name == "transform": continue - prefixed_method = prefix + "__" + method_name + prefixed_method = prefix_format.format(prefix, method_name) if hasattr(base_transformer, prefixed_method): raise AttributeError("Cannot merge: method '%s' appears more than once" % prefixed_method) + if hasattr(method, "_rule_name"): + method = _make_merged_method(prefix, method) setattr(base_transformer, prefixed_method, method) return base_transformer diff --git a/tests/test_trees.py b/tests/test_trees.py index 12b0bedd..5daf2e88 100644 --- a/tests/test_trees.py +++ b/tests/test_trees.py @@ -490,6 +490,54 @@ def float_(self, values) -> float: res = TT().transform(t) self.assertEqual(res, 2.9) + def test_call_for_transformer_merge(self): + t = Tree('add', [Tree('sub', [Tree('i', ['3']), Tree('t2__f', ['1.1'])]), Tree('i', ['1'])]) + + class T1(Transformer): + @call_for("i") + def int_(self, values) -> int: + return int(values[0]) + + def sub(self, values): + return values[0] - values[1] + + def add(self, values): + return sum(values) + + class T2(Transformer): + @call_for("f") + def float_(self, values) -> float: + return float(values[0]) + + merged_transformer = merge_transformers(T1(), t2=T2()) + + res = merged_transformer.transform(t) + self.assertEqual(res, 2.9) + + def test_call_for_transformer_merge_no_base(self): + t = Tree('t1__add', [Tree('t1__sub', [Tree('t1__i', ['3']), Tree('t2__f', ['1.1'])]), Tree('t1__i', ['1'])]) + + class T1(Transformer): + @call_for("i") + def int_(self, values) -> int: + return int(values[0]) + + def sub(self, values): + return values[0] - values[1] + + def add(self, values): + return sum(values) + + class T2(Transformer): + @call_for("f") + def float_(self, values) -> float: + return float(values[0]) + + merged_transformer = merge_transformers(t1=T1(), t2=T2()) + + res = merged_transformer.transform(t) + self.assertEqual(res, 2.9) + def test_call_for_visitor(self): t = Tree('add', [Tree('sub', [Tree('i', ['3']), Tree('f', ['1.1'])]), Tree('i', ['1'])]) From fb0b46c5f2973fab671d87e1687c262cd62db8cc Mon Sep 17 00:00:00 2001 From: Patrick Lannigan Date: Mon, 17 Jan 2022 19:41:38 -0500 Subject: [PATCH 3/3] Add super init calls that are now expected --- examples/calc.py | 1 + lark/load_grammar.py | 4 ++++ lark/reconstruct.py | 1 + lark/tools/nearley.py | 1 + lark/tree_templates.py | 1 + tests/test_trees.py | 4 ++++ 6 files changed, 12 insertions(+) diff --git a/examples/calc.py b/examples/calc.py index 9e9aa78f..3474f34c 100644 --- a/examples/calc.py +++ b/examples/calc.py @@ -46,6 +46,7 @@ class CalculateTree(Transformer): number = float def __init__(self): + super().__init__() self.vars = {} def assign_var(self, name, value): diff --git a/lark/load_grammar.py b/lark/load_grammar.py index d06e88cc..6682b9ef 100644 --- a/lark/load_grammar.py +++ b/lark/load_grammar.py @@ -196,6 +196,7 @@ @inline_args class EBNF_to_BNF(Transformer_InPlace): def __init__(self): + super().__init__() self.new_rules = [] self.rules_cache = {} self.prefix = 'anon' @@ -421,6 +422,7 @@ class PrepareAnonTerminals(Transformer_InPlace): """Create a unique list of anonymous terminals. Attempt to give meaningful names to them when we add them""" def __init__(self, terminals): + super().__init__() self.terminals = terminals self.term_set = {td.name for td in self.terminals} self.term_reverse = {td.pattern: td for td in terminals} @@ -476,6 +478,7 @@ class _ReplaceSymbols(Transformer_InPlace): """Helper for ApplyTemplates""" def __init__(self): + super().__init__() self.names = {} def value(self, c): @@ -494,6 +497,7 @@ class ApplyTemplates(Transformer_InPlace): """Apply the templates, creating new rules that represent the used templates""" def __init__(self, rule_defs): + super().__init__() self.rule_defs = rule_defs self.replacer = _ReplaceSymbols() self.created_templates = set() diff --git a/lark/reconstruct.py b/lark/reconstruct.py index 02b49476..5acff709 100644 --- a/lark/reconstruct.py +++ b/lark/reconstruct.py @@ -27,6 +27,7 @@ class WriteTokensTransformer(Transformer_InPlace): term_subs: Dict[str, Callable[[Symbol], str]] def __init__(self, tokens: Dict[str, TerminalDef], term_subs: Dict[str, Callable[[Symbol], str]]) -> None: + super().__init__() self.tokens = tokens self.term_subs = term_subs diff --git a/lark/tools/nearley.py b/lark/tools/nearley.py index 8b8ef89e..394a64d0 100644 --- a/lark/tools/nearley.py +++ b/lark/tools/nearley.py @@ -53,6 +53,7 @@ def _get_rulename(name): @v_args(inline=True) class NearleyToLark(Transformer): def __init__(self): + super().__init__() self._count = 0 self.extra_rules = {} self.extra_rules_rev = {} diff --git a/lark/tree_templates.py b/lark/tree_templates.py index 8db8b35c..3ea1ca5a 100644 --- a/lark/tree_templates.py +++ b/lark/tree_templates.py @@ -76,6 +76,7 @@ def _match_tree_template(self, template, tree): class _ReplaceVars(Transformer): def __init__(self, conf, vars): + super().__init__() self._conf = conf self._vars = vars diff --git a/tests/test_trees.py b/tests/test_trees.py index 5daf2e88..47f39b15 100644 --- a/tests/test_trees.py +++ b/tests/test_trees.py @@ -51,12 +51,14 @@ def test_iter_subtrees_topdown(self): def test_visitor(self): class Visitor1(Visitor): def __init__(self): + super().__init__() self.nodes=[] def __default__(self,tree): self.nodes.append(tree) class Visitor1_Recursive(Visitor_Recursive): def __init__(self): + super().__init__() self.nodes=[] def __default__(self,tree): @@ -543,6 +545,7 @@ def test_call_for_visitor(self): class V(Visitor): def __init__(self) -> None: + super().__init__() self.int_called = False self.float_called = False @@ -565,6 +568,7 @@ def test_call_for_visitor_subclass(self): class V(Visitor): def __init__(self) -> None: + super().__init__() self.int_called = False self.float_called = False