diff --git a/src/core/sheerka/services/SheerkaRuleManager.py b/src/core/sheerka/services/SheerkaRuleManager.py index 0cb29c4..5fb7a0e 100644 --- a/src/core/sheerka/services/SheerkaRuleManager.py +++ b/src/core/sheerka/services/SheerkaRuleManager.py @@ -18,7 +18,8 @@ from core.tokenizer import Keywords, TokenKind, Token, IterParser from core.utils import index_tokens, COLORS, get_text_from_tokens, merge_dictionaries, merge_sets from evaluators.ConceptEvaluator import ConceptEvaluator from evaluators.PythonEvaluator import PythonEvaluator, Expando -from parsers.BaseExpressionParser import AndNode, ExpressionVisitor, VariableNode, ComparisonNode, FunctionNode +from parsers.BaseExpressionParser import AndNode, ExpressionVisitor, VariableNode, ComparisonNode, FunctionNode, \ + ComparisonType from parsers.BaseNodeParser import SourceCodeWithConceptNode, ConceptNode, SourceCodeNode from parsers.LogicalOperatorParser import LogicalOperatorParser from parsers.PythonParser import PythonNode @@ -1377,8 +1378,7 @@ class PythonConditionExprVisitor(ExpressionVisitor): self.variables[target] = var_name return var_name - def init_or_get_variable_from_name(self, variable_path: List[str], obj_variables): - + def get_variable_from_name(self, variable_path: List[str]): if len(variable_path) > 1: left = variable_path[:-1] right = [variable_path[-1]] @@ -1389,6 +1389,15 @@ class PythonConditionExprVisitor(ExpressionVisitor): return self.variables[var_name], ".".join(right) right.insert(0, left.pop()) + else: + return variable_path[0], ".".join(variable_path[1:]) + + return variable_path + + def init_or_get_variable_from_name(self, variable_path: List[str], obj_variables): + var_root, var_attr = self.get_variable_from_name(variable_path) + if var_root != variable_path[0]: + return var_root, var_attr if variable_path[0] not in self.variables: self.add_variable(variable_path[0]) @@ -1413,12 +1422,24 @@ class PythonConditionExprVisitor(ExpressionVisitor): return PythonConditionExprVisitorObj(source, source, {}, {expr_node.name}) def visit_ComparisonNode(self, expr_node: ComparisonNode): - if isinstance(expr_node.left, VariableNode): - source = expr_node.get_source() - return PythonConditionExprVisitorObj(source, source, {}, {expr_node.left.name}) - else: + if not isinstance(expr_node.left, VariableNode): raise FailedToCompileError([expr_node]) + res = evaluate(self.context, + expr_node.right.get_source(), + evaluators=CONDITIONS_VISITOR_EVALUATORS, + desc=None, + eval_body=False, + eval_where=False, + is_question=False, + expect_success=False, + stm=None) + res = expect_one(self.context, res) + if not res.status: + return FailedToCompileError([f"Cannot recognize '{expr_node.right.get_source()}'"]) + + return self.create_comparison_condition(expr_node.left.unpack(), expr_node.comp, res.value) + def visit_AndNode(self, expr_node: AndNode): current_visitor_obj = self.visit(expr_node.parts[0]) for node in expr_node.parts[1:]: @@ -1473,3 +1494,16 @@ class PythonConditionExprVisitor(ExpressionVisitor): concept_variables.update({k: v for k, v in concept.variables().items() if v is not NotInit}) return PythonConditionExprVisitorObj(source, source, {}, obj_variables) + + def create_comparison_condition(self, var_path, op, value): + var_root, var_attr = self.get_variable_from_name(var_path) + left = var_root + "." + var_attr + if op == ComparisonType.EQUALS: + if isinstance(value, Expando): + source = f"isinstance({left}, Expando) and {left} == {value.get_name()}" + return PythonConditionExprVisitorObj(source, source, {}, set()) + if isinstance(value, Concept): + return self.recognize_concept(var_path, value, {}) + else: + source = ComparisonNode.rebuild_source(left, op, value) + return PythonConditionExprVisitorObj(source, source, {}, {var_path[0]}) diff --git a/src/parsers/BaseExpressionParser.py b/src/parsers/BaseExpressionParser.py index 0bb55f9..e211db4 100644 --- a/src/parsers/BaseExpressionParser.py +++ b/src/parsers/BaseExpressionParser.py @@ -311,6 +311,35 @@ class ComparisonNode(ExprNode): def __str__(self): return f"{self.left} {self.comp} {self.right}" + @staticmethod + def rebuild_source(left, op, right): + if isinstance(right, str): + right = f"'{right}'" + + if op == ComparisonType.EQUALS: + return f"{left} == {right}" + + if op == ComparisonType.NOT_EQUAlS: + return f"{left} != {right}" + + if op == ComparisonType.LESS_THAN: + return f"{left} < {right}" + + if op == ComparisonType.LESS_THAN_OR_EQUALS: + return f"{left} <= {right}" + + if op == ComparisonType.GREATER_THAN: + return f"{left} > {right}" + + if op == ComparisonType.GREATER_THAN_OR_EQUALS: + return f"{left} >= {right}" + + if op == ComparisonType.IN: + return f"{left} in ({right})" + + if op == ComparisonType.NOT_IN: + return f"{left} not in ({right})" + @dataclass() class FunctionParameter: diff --git a/tests/core/test_SheerkaRuleManager.py b/tests/core/test_SheerkaRuleManager.py index a357189..dc30eae 100644 --- a/tests/core/test_SheerkaRuleManager.py +++ b/tests/core/test_SheerkaRuleManager.py @@ -1334,45 +1334,33 @@ isinstance(var, Concept) and var.key == 'hello __var__0'""" + \ assert self.sheerka.is_success(self.sheerka.objvalue(res)) @pytest.mark.parametrize("expression, variable_name, expected_compiled", [ - ( - "recognize(__ret.body, greetings)", - None, - "__x_00__ = __ret.body\nisinstance(__x_00__, Concept) and __x_00__.name == 'greetings'" - ), + # ( + # "recognize(__ret.body, greetings)", + # None, + # "__x_00__ = __ret.body\nisinstance(__x_00__, Concept) and __x_00__.name == 'greetings'" + # ), # ( # "recognize(__ret.body, c:|1001:)", # None, - # ["#__x_00__|__name__|'__ret'", - # "#__x_00__|body|#__x_01__", - # "#__x_01__|__is_concept__|True", - # "#__x_01__|id|'1001'"] + # "__x_00__ = __ret.body\nisinstance(__x_00__, Concept) and __x_00__.id == '1001'" # ), # ( # "recognize(__ret.body, c:greetings:)", # None, - # ["#__x_00__|__name__|'__ret'", - # "#__x_00__|body|#__x_01__", - # "#__x_01__|__is_concept__|True", - # "#__x_01__|name|'greetings'"] + # "__x_00__ = __ret.body\nisinstance(__x_00__, Concept) and __x_00__.name == 'greetings'" # ), # ( # "recognize(__ret.body, greetings) and __ret.body.a == 'my friend'", # "my friend", - # ["#__x_00__|__name__|'__ret'", - # "#__x_00__|body|#__x_01__", - # "#__x_01__|__is_concept__|True", - # "#__x_01__|name|'greetings'", - # "#__x_01__|a|'my friend'"] - # ), - # ( - # "recognize(__ret.body, greetings) and __ret.body.a == sheerka", - # "sheerka", - # ["#__x_00__|__name__|'__ret'", - # "#__x_00__|body|#__x_01__", - # "#__x_01__|__is_concept__|True", - # "#__x_01__|name|'greetings'", - # "#__x_01__|a|'__sheerka__'"] + # "__x_00__ = __ret.body\nisinstance(__x_00__, Concept) and __x_00__.name == 'greetings' and __x_00__.a == 'my friend'" # ), + ( + "recognize(__ret.body, greetings) and __ret.body.a == sheerka", + "sheerka", + """__x_00__ = __ret.body +__x_01__ = __x_00__.a +isinstance(__x_00__, Concept) and __x_00__.name == 'greetings' and isinstance(__x_01__, Expando) and __x_01__.name == 'sheerka'""" + ), # ( # "recognize(__ret.body, greetings) and __ret.body.a == foo", # "foo", diff --git a/tests/parsers/test_ExpressionParser.py b/tests/parsers/test_ExpressionParser.py index 367aa6f..41ca5b8 100644 --- a/tests/parsers/test_ExpressionParser.py +++ b/tests/parsers/test_ExpressionParser.py @@ -3,7 +3,7 @@ import pytest from core.builtin_concepts_ids import BuiltinConcepts from core.sheerka.services.SheerkaExecute import ParserInput from core.tokenizer import Tokenizer -from parsers.BaseExpressionParser import VariableNode +from parsers.BaseExpressionParser import VariableNode, ComparisonNode from parsers.BaseParser import ErrorSink from parsers.ExpressionParser import ExpressionParser from tests.TestUsingMemoryBasedSheerka import TestUsingMemoryBasedSheerka @@ -56,7 +56,7 @@ class TestExpressionParser(TestUsingMemoryBasedSheerka): ("func(var1.attr1 > var2.attr2)", FN("func(", ")", [GT(VAR("var1.attr1"), VAR("var2.attr2"))])), ("func1(var1) and func2(var2)", AND(FN("func1(", ")", [VAR("var1")]), FN("func2(", (")", 1), [VAR("var2")]))), ("__ret", VAR("__ret")), - #("func1().func2()", []) + # ("func1().func2()", []) ]) def test_i_can_parse_input(self, expression, expected): sheerka, context, parser, parser_input, error_sink = self.init_parser_with_source(expression) @@ -91,3 +91,14 @@ class TestExpressionParser(TestUsingMemoryBasedSheerka): assert not error_sink.has_error assert parsed == get_expr_node_from_test_node(expression, EXPR("var1 + var2")) + + @pytest.mark.parametrize("expression, expected", [ + ("ret.status in ('a', 1 , func())", "new_var in ('a', 1 , func())"), + ("ret.status not in ('a', 1 , func())", "new_var not in ('a', 1 , func())"), + + ]) + def test_i_can_rebuild_source(self, expression, expected): + sheerka, context, parser, parser_input, error_sink = self.init_parser_with_source(expression) + parsed = parser.parse_input(context, parser_input, error_sink) + + assert ComparisonNode.rebuild_source("new_var", parsed.comp, parsed.right.get_source()) == expected