from __future__ import annotations

from itertools import product
from typing import TYPE_CHECKING, Generator, Union

from core.builtin_concepts_ids import BuiltinConcepts
from core.concept import Concept
from core.global_symbols import NotInit
from core.rule import Rule, ACTION_TYPE_PRINT
from core.utils import as_bag
from evaluators.PythonEvaluator import Expando
from sheerkapickle.utils import is_primitive
from sheerkarete.alpha import AlphaMemory
from sheerkarete.beta import ReteNode, BetaMemory
from sheerkarete.bind_node import BindNode
from sheerkarete.common import WME, Match, V
from sheerkarete.conditions import Condition, NegatedCondition, NegatedConjunctiveConditions, FilterCondition, \
    BindCondition
from sheerkarete.filter_node import FilterNode
from sheerkarete.join_node import JoinNode
from sheerkarete.ncc_node import NccNode, NccPartnerNode
from sheerkarete.negative_node import NegativeNode
from sheerkarete.pnode import PNode

if TYPE_CHECKING:  # pragma: no cover
    from typing import Dict
    from typing import Tuple
    from typing import List
    from typing import Set
    from typing import Hashable

# Attribute set on objects registered through ReteNetwork.add_obj(); it holds the fact id
# ("f-00000", "f-00001", ...) under which the object is stored in ReteNetwork.facts.
FACT_ID = "##fact_id##"


class ReteNetwork:
    """
    A Rete Network to store all the facts and productions to compute matches.

    Alpha memories are indexed in ``alpha_hash`` by the (identifier, attribute, value) key of
    their condition, ``'*'`` standing for "any value" (see add_wme()). Several alpha memories
    can share the same key when their conditions differ, hence the list values.
    """

    def __init__(self):
        # condition key -> alpha memories registered under that key
        self.alpha_hash: Dict[Tuple[Hashable, Hashable, Hashable], List[AlphaMemory]] = {}
        self.beta_root = ReteNode()  # root of the beta network
        self.pnodes: List[PNode] = []  # list of all production nodes
        self.rules: Set[Rule] = set()  # set of all known rules
        self.working_memory: Set[WME] = set()
        self.fact_counter: int = 0  # used to generate the ids of the objects added with add_obj()
        self.facts: Dict[str, object] = {}  # fact id -> object added with add_obj()
        self.attributes_by_id = {}  # keep track of requested conditions attributes, for a given id
        self.default_attributes = set()  # keep track of requested attributes, when the id is not given

    @property
    def matches(self) -> Generator[Match, None, None]:
        """Yield a Match for every activation currently held by every production node."""
        for pnode in self.pnodes:
            for t in pnode.activations:
                yield Match(pnode, t)

    def build_or_share_alpha_memory(self, condition):
        """
        Return the alpha memory for ``condition``, creating it if needed.

        A newly created memory is immediately activated with every WME of the working memory
        that passes ``condition.test``.

        :type condition: Condition
        :rtype: AlphaMemory
        """
        key = condition.get_key()

        # return existing alpha memory if it exists
        for amem in self.alpha_hash.get(key, ()):
            if amem.condition == condition:
                return amem

        # or create a new one
        amem = AlphaMemory(key, condition)
        self.alpha_hash.setdefault(key, []).append(amem)

        # fire already created WME if needed
        for w in self.working_memory:
            if condition.test(w):
                amem.activation(w)

        return amem

    def build_or_share_beta_memory(self, parent):
        """
        Create or reuse a BetaMemory below ``parent``.

        Only a child whose exact type is BetaMemory is reused; subclasses are deliberately
        not shared as plain beta memories.
        """
        for child in parent.children:
            if type(child) is BetaMemory:  # exact type on purpose: subclasses must not match
                return child

        node = BetaMemory(parent=parent)
        parent.children.append(node)
        # a fresh memory must receive the tokens already produced above it
        self.update_new_node_with_matches_from_above(node)
        return node

    def build_or_share_join_node(self, parent: BetaMemory, amem: AlphaMemory, condition: Condition) -> JoinNode:
        """
        Creates or reuse a JoinNode
        :param parent: parent beta memory
        :param amem: parent alpha memory
        :param condition: condition for the join
        :returns: the (possibly shared) join node
        """
        # search for already created join node (exact type: NegativeNode must not match)
        for child in parent.all_children:
            if type(child) is JoinNode and child.amem == amem and child.condition == condition:
                return child

        node = JoinNode(children=[], parent=parent, amem=amem, condition=condition)
        parent.children.append(node)
        parent.all_children.append(node)
        amem.successors.append(node)
        amem.reference_count += 1
        node.update_nearest_ancestor_with_same_amem()

        # little optimisation: a join cannot produce anything while one of its sides is empty,
        # so unlink it from the right (empty parent) or from the left (empty alpha memory)
        if not parent.items:
            amem.successors.remove(node)
        elif not amem.items:
            parent.children.remove(node)

        return node

    def build_or_share_negative_node(self, parent: JoinNode,
                                     amem: AlphaMemory,
                                     condition: NegatedCondition) -> NegativeNode:
        """
        Create or reuse a NegativeNode testing the absence of WMEs matching ``condition``.
        The new node takes a reference on ``amem`` and is fed with the matches from above.
        """
        # search for already created negative node
        for child in parent.children:
            if isinstance(child, NegativeNode) and child.amem == amem and child.condition == condition:
                return child

        node = NegativeNode(parent=parent, amem=amem, condition=condition)
        parent.children.append(node)
        amem.successors.append(node)
        amem.reference_count += 1
        node.update_nearest_ancestor_with_same_amem()
        self.update_new_node_with_matches_from_above(node)

        # little optimisation: without tokens there is nothing to block, unlink from the right
        if not node.items:
            amem.successors.remove(node)

        return node

    def build_or_share_ncc_nodes(self, parent: JoinNode, ncc: NegatedConjunctiveConditions,
                                 earlier_conds: List[Condition]) -> NccNode:
        """
        Create or reuse the NccNode / NccPartnerNode pair for a negated conjunction.

        The subnetwork matching the inner conditions of ``ncc`` is built (or shared) below
        ``parent``; its bottom node becomes the parent of the partner node.
        """
        bottom_of_subnetwork = self.build_or_share_network_for_conditions(parent, ncc, earlier_conds)

        # search for already created ncc node
        for child in parent.children:
            if isinstance(child, NccNode) and child.partner.parent == bottom_of_subnetwork:
                return child

        ncc_partner = NccPartnerNode(parent=bottom_of_subnetwork)
        ncc_node = NccNode(partner=ncc_partner, children=[], parent=parent)
        ncc_partner.ncc_node = ncc_node
        # inserted first, presumably so that it is activated before the subnetwork feeding
        # its partner (as in Doorenbos' Rete) - confirm
        parent.children.insert(0, ncc_node)
        bottom_of_subnetwork.children.append(ncc_partner)
        ncc_partner.number_of_conditions = ncc.number_of_conditions

        self.update_new_node_with_matches_from_above(ncc_node)
        self.update_new_node_with_matches_from_above(ncc_partner)

        return ncc_node

    def build_or_share_filter_node(self, parent: ReteNode, f: FilterCondition) -> FilterNode:
        """Create or reuse a FilterNode applying ``f.func`` below ``parent``."""
        # search for already created filter node
        for child in parent.children:
            if isinstance(child, FilterNode) and child.func == f.func:
                return child

        node = FilterNode([], parent, f.func, self)
        parent.children.append(node)
        return node

    def build_or_share_bind_node(self, parent: ReteNode, b: BindCondition) -> BindNode:
        """Create or reuse a BindNode binding the result of ``b.func`` to ``b.to`` below ``parent``."""
        # search for already created bind node
        for child in parent.children:
            if isinstance(child, BindNode) and child.func == b.func and child.bind == b.to:
                return child

        node = BindNode([], parent, b.func, b.to, self)
        parent.children.append(node)
        return node

    def build_or_share_p_node(self, parent: JoinNode, rule: Rule) -> Union[PNode, None]:
        """
        Create or reuse a production node
        :param parent: parent join node
        :param rule: rule that will be fired on activation
        :return: returns None if the PNode already exists (the rule is then attached to it)
        """
        for child in parent.children:
            if isinstance(child, PNode):
                child.rules.append(rule)
                rule.rete_p_nodes.append(child)
                return None

        node = PNode(rule=rule, parent=parent)
        parent.children.append(node)
        self.update_new_node_with_matches_from_above(node)
        rule.rete_p_nodes.append(node)
        return node

    def build_or_share_network_for_conditions(self, parent, conditions, earlier_conditions) -> ReteNode:
        """
        Build (or reuse) the chain of nodes matching ``conditions`` below ``parent``.

        While walking the conditions, the attributes they request are recorded in
        ``attributes_by_id`` (when the fact they apply to is known) or in
        ``default_attributes``; add_obj() uses them to decide which attributes become WMEs.

        :param parent: node under which the chain is built
        :param conditions: conditions to match, in order
        :param earlier_conditions: conditions already matched above ``parent`` (not modified)
        :return: the bottom node of the chain
        """
        current_node = parent
        # work on a copy: the caller's list must not be extended with our conditions
        conds_higher_up = list(earlier_conditions)

        # vars_ids_mappings maps a variable to the name of the fact it stands for. With
        #   conditions = [Condition(V("x"), "__name__", "fact_name"),
        #                 Condition(V("x"), "attr1", "value1"),
        #                 Condition(V("x"), "attr2", V("y"))]
        # V("x") refers to the object named 'fact_name', so '__name__', 'attr1' and 'attr2' are
        # recorded under 'fact_name' in self.attributes_by_id, and V("y") is mapped to 'fact_name.attr2'.
        vars_ids_mappings = {}

        for cond in conditions:
            # update requested attributes for a fact
            if isinstance(cond, Condition):
                # a '__name__' condition on a variable tells which fact the variable stands for
                if isinstance(cond.identifier, V) and cond.attribute == "__name__":
                    vars_ids_mappings[cond.identifier] = cond.value

                # a new variable bound to an attribute of a known fact stands for that sub-fact
                if (cond.identifier in vars_ids_mappings and
                        isinstance(cond.attribute, str) and
                        isinstance(cond.value, V)):
                    vars_ids_mappings[cond.value] = f"{vars_ids_mappings[cond.identifier]}.{cond.attribute}"

                if cond.identifier in vars_ids_mappings:
                    identifier = vars_ids_mappings[cond.identifier]
                elif not isinstance(cond.identifier, V):
                    identifier = cond.identifier
                else:
                    identifier = None  # unbound variable: the fact is not known

                if identifier:
                    attr = "*" if isinstance(cond.attribute, V) else cond.attribute
                    self.attributes_by_id.setdefault(identifier, []).append(attr)
                elif not isinstance(cond.attribute, V):
                    self.default_attributes.add(cond.attribute)

            # create the alpha memory (if needed), beta memory and join node
            if isinstance(cond, Condition) and not isinstance(cond, NegatedCondition):
                am = self.build_or_share_alpha_memory(cond)
                current_node = self.build_or_share_beta_memory(current_node)
                current_node = self.build_or_share_join_node(current_node, am, cond)
            elif isinstance(cond, NegatedCondition):
                am = self.build_or_share_alpha_memory(cond)
                current_node = self.build_or_share_negative_node(current_node, am, cond)
            elif isinstance(cond, NegatedConjunctiveConditions):
                current_node = self.build_or_share_ncc_nodes(current_node, cond, conds_higher_up)
            elif isinstance(cond, FilterCondition):
                current_node = self.build_or_share_filter_node(current_node, cond)
            elif isinstance(cond, BindCondition):
                current_node = self.build_or_share_bind_node(current_node, cond)

            conds_higher_up.append(cond)

        return current_node

    def get_rete_conditions(self, rule):
        """
        Gets the conditions from a rule
        It's in fact the list of disjunctions
        Not sure yet which component will hold this functionality

        :raises NotImplementedError: if the rule does not provide get_rete_disjunctions()
        """
        if hasattr(rule, "get_rete_disjunctions"):
            return rule.get_rete_disjunctions()

        raise NotImplementedError("Rule does not provide get_rete_disjunctions()")

    def add_rule(self, rule: Rule):
        """
        Add a rule to the network: one chain of nodes ending with a production node is
        built (or shared) per disjunction of the rule.

        Disabled, not compiled and print rules are silently ignored.

        :raises ValueError: if the rule has no id or is already in a network
        """
        if rule.id is None:
            raise ValueError("Rule has no id, cannot add")

        if (not rule.metadata.is_enabled or
                not rule.metadata.is_compiled or
                rule.metadata.action_type == ACTION_TYPE_PRINT):
            return

        if rule.rete_net:
            raise ValueError("Rule is already added")

        rule.rete_net = self
        self.rules.add(rule)
        for full_condition in self.get_rete_conditions(rule):
            conditions = full_condition.conditions
            current_node = self.build_or_share_network_for_conditions(self.beta_root, conditions, [])
            p_node = self.build_or_share_p_node(current_node, rule)
            if p_node is not None:
                self.pnodes.append(p_node)

    def remove_rule(self, rule: Rule):
        """
        Remove a rule from the network.

        Production nodes no longer used by any rule are deleted together with the ancestors
        only they were using. Does nothing if the rule is not attached to a network.
        """
        if rule.rete_net is None:
            return

        # Remove production
        self.rules.remove(rule)
        for pnode in rule.rete_p_nodes:
            pnode.rules.remove(rule)
            if not pnode.rules:
                self.delete_node_and_any_unused_ancestors(pnode)
                self.pnodes.remove(pnode)

        # detach the rule so that it can be added again later
        rule.rete_p_nodes = []
        rule.rete_net = None

    def add_wme(self, wme: WME) -> None:
        """
        Add a WME to the working memory and activate every alpha memory whose key matches it
        (identifier, attribute and value each either equal or '*').
        Adding a WME already in the working memory does nothing.
        """
        if wme in self.working_memory:
            return

        keys = product([wme.identifier, '*'], [wme.attribute, '*'], [wme.value, '*'])
        for key in keys:
            if key in self.alpha_hash:
                for amem in reversed(self.alpha_hash[key]):
                    amem.activation(wme)

        self.working_memory.add(wme)

    def remove_wme(self, wme: WME) -> None:
        """
        Remove a WME from the working memory and retract everything it produced.

        The instance stored in the working memory (the one holding the alpha memory and
        token bookkeeping) is looked up by equality. Removing a WME that is not in the
        working memory does nothing, mirroring add_wme() which ignores duplicates.
        """
        stored_wme = next((w for w in self.working_memory if wme == w), None)
        if stored_wme is None:
            return
        wme = stored_wme

        for am in wme.amems:
            am.items.remove(wme)
            if not am.items:
                # the alpha memory became empty: unlink its join nodes from the left
                for node in am.successors:
                    if isinstance(node, JoinNode) and not isinstance(node, NegativeNode):
                        node.parent.children.remove(node)

        # delete_token_and_descendents() presumably removes the token from wme.tokens
        # (otherwise this loop would never end)
        while wme.tokens:
            wme.tokens[0].delete_token_and_descendents()

        for jr in wme.negative_join_results:
            owner = jr.owner
            owner.join_results.remove(jr)
            if not owner.join_results:
                # nothing blocks this token anymore: propagate it below its negative node
                if owner.node and owner.node.children is not None:
                    for child in owner.node.children:
                        child.left_activation(owner, None, owner.binding)

        self.working_memory.remove(wme)

    def remove_wme_by_fact_id(self, identifier: str) -> None:
        """
        Remove every WME of the fact ``identifier`` and of its sub-facts (whose ids are
        ``"<identifier>.<attribute>"``, see add_obj()).
        """
        prefix = identifier + "."
        to_remove = [wme for wme in self.working_memory
                     if wme.identifier == identifier
                     or (isinstance(wme.identifier, str) and wme.identifier.startswith(prefix))]
        for wme in to_remove:
            self.remove_wme(wme)

    def add_obj(self, name, obj, fact_id=None, use_bag=False):
        """
        Adds a new object to the working memory, as a set of WMEs.

        Only the attributes requested by the conditions are added: those recorded for
        ``name`` in ``attributes_by_id`` or, when ``name`` is unknown, ``default_attributes``.
        Concept values (except for the 'self' attribute) are added recursively as sub-facts
        named ``"<name>.<attribute>"`` with the id ``"<fact_id>.<attribute>"``; NotInit values
        are skipped, and values standing for sheerka are replaced by the "__sheerka__" marker.

        :param name: name of the fact; selects the requested attributes and is the value of
            its '__name__' WME
        :param obj: object to add
        :param fact_id: id to use; when None a new id is generated, stored on ``obj`` under
            FACT_ID, and ``obj`` is registered in ``self.facts``
        :param use_bag: add every attribute returned by as_bag() instead of the requested ones
        :raises ValueError: if no fact_id is given and ``obj`` already has one
        """

        def inner_add_wme(name_, fact_id_, attr_, value_):
            # add one attribute value; Concept values become sub-facts
            if value_ is NotInit:
                return
            if attr_ != "self" and isinstance(value_, Concept):
                sub_fact_id = f"{fact_id_}.{attr_}"
                self.add_wme(WME(fact_id_, attr_, sub_fact_id))
                self.add_obj(f"{name_}.{attr_}", value_, sub_fact_id)
            else:
                self.add_wme(WME(fact_id_, attr_, value_))

        if fact_id is None:
            if hasattr(obj, FACT_ID):
                raise ValueError("Object already has an id, cannot add")
            fact_id = f"f-{self.fact_counter:05}"
            setattr(obj, FACT_ID, fact_id)
            self.facts[fact_id] = obj
            self.fact_counter += 1

        if use_bag:
            requested_attributes = ("*",)
        else:
            requested_attributes = self.attributes_by_id.get(name, self.default_attributes)

        for attribute in requested_attributes:
            if attribute == "*":
                for k, v in as_bag(obj).items():
                    inner_add_wme(name, fact_id, k, v)
            elif attribute == "__name__":
                self.add_wme(WME(fact_id, "__name__", name))
            elif attribute == "__is_concept__":
                self.add_wme(WME(fact_id, "__is_concept__", isinstance(obj, Concept)))
            else:
                try:
                    value = getattr(obj, attribute)
                except AttributeError:
                    continue  # the object does not have this attribute: nothing to add

                if (isinstance(value, Concept) and value.key == BuiltinConcepts.SHEERKA or
                        isinstance(value, Expando) and value.get_name() == "sheerka"):
                    value = "__sheerka__"

                if is_primitive(value):
                    self.add_wme(WME(fact_id, attribute, value))
                else:
                    inner_add_wme(name, fact_id, attribute, value)

    def remove_obj(self, obj):
        """
        Remove an object added with add_obj() (and its sub-facts) from the working memory.

        :raises ValueError: if the object has no fact id or is not registered in this network
        """
        if not hasattr(obj, FACT_ID) or (fact_id := getattr(obj, FACT_ID)) not in self.facts:
            raise ValueError("Fact has no id or does not exist in network.")

        self.remove_wme_by_fact_id(fact_id)
        del self.facts[fact_id]
        delattr(obj, FACT_ID)

    def update_new_node_with_matches_from_above(self, new_node: ReteNode) -> None:
        """
        Feed a newly created node with the matches already computed by its parent, so that
        nodes created after facts were added still see those facts.
        """
        parent = new_node.parent
        if parent == self.beta_root:
            new_node.left_activation(None, None, {})
        elif isinstance(parent, BetaMemory) and not isinstance(parent, (NccNode, NegativeNode)):
            for tok in parent.items:
                new_node.left_activation(token=tok)
        elif isinstance(parent, JoinNode) and not isinstance(parent, NegativeNode):
            # replay the alpha memory through the join with new_node as its only child;
            # the real children are restored even if an activation fails
            saved_list_of_children = parent.children
            parent.children = [new_node]
            try:
                for item in parent.amem.items:
                    parent.right_activation(item)
            finally:
                parent.children = saved_list_of_children
        elif isinstance(parent, NegativeNode):
            for token in parent.items:
                if not token.join_results:
                    new_node.left_activation(token, None, token.binding)
        elif isinstance(parent, NccNode):
            for token in parent.items:
                if not token.ncc_results:
                    new_node.left_activation(token, None, token.binding)
        elif isinstance(parent, (BindNode, FilterNode)):
            # these nodes are not replayed from their own items: update them from their own
            # parent with new_node as their only child, so matches flow through them
            saved_list_of_children = parent.children
            parent.children = [new_node]
            try:
                self.update_new_node_with_matches_from_above(parent)
            finally:
                parent.children = saved_list_of_children

    def delete_alpha_memory(self, amem: AlphaMemory):
        """
        Remove ``amem`` from the alpha index.

        Several alpha memories can share the same key (see build_or_share_alpha_memory()),
        so only this one is removed; the key is dropped once no memory uses it anymore.
        """
        memories = self.alpha_hash.get(amem.key)
        if memories is None:
            return

        for index, memory in enumerate(memories):
            if memory is amem:
                del memories[index]
                break

        if not memories:
            del self.alpha_hash[amem.key]

    def _release_alpha_memory(self, amem: AlphaMemory) -> None:
        """Drop one reference on ``amem`` and delete it once no node uses it anymore."""
        amem.reference_count -= 1
        if amem.reference_count == 0:
            self.delete_alpha_memory(amem)

    def delete_node_and_any_unused_ancestors(self, node: ReteNode):
        """
        Delete ``node`` (its tokens, its alpha memory reference and its link to its parent),
        then recursively delete its parent if the parent has no child left.
        """
        if isinstance(node, NccNode):
            self.delete_node_and_any_unused_ancestors(node.partner)

        # iterate over copies: deleting a token may remove it from the list being walked
        if isinstance(node, BetaMemory):
            for item in list(node.items):
                item.delete_token_and_descendents()
        if isinstance(node, NccPartnerNode):
            for item in list(node.new_result_buffer):
                item.delete_token_and_descendents()

        if isinstance(node, NegativeNode):
            # negative nodes also hold a reference on their alpha memory
            # (taken in build_or_share_negative_node) which must be released
            if node in node.amem.successors:
                node.amem.successors.remove(node)
            self._release_alpha_memory(node.amem)

        if isinstance(node, JoinNode) and not isinstance(node, NegativeNode):
            if not node.right_unlinked:
                node.amem.successors.remove(node)
            self._release_alpha_memory(node.amem)
            if not node.left_unlinked:
                node.parent.children.remove(node)
            node.parent.all_children.remove(node)
            if not node.parent.all_children:
                self.delete_node_and_any_unused_ancestors(node.parent)
        elif node.parent:
            node.parent.children.remove(node)
            if not node.parent.children:
                self.delete_node_and_any_unused_ancestors(node.parent)