500 lines
19 KiB
Python
500 lines
19 KiB
Python
from __future__ import annotations
|
|
|
|
from itertools import product
|
|
from typing import TYPE_CHECKING, Generator, Union
|
|
|
|
from core.builtin_concepts_ids import BuiltinConcepts
|
|
from core.concept import Concept
|
|
from core.global_symbols import NotInit
|
|
from core.rule import Rule, ACTION_TYPE_PRINT
|
|
from core.utils import as_bag
|
|
from evaluators.PythonEvaluator import Expando
|
|
from sheerkapickle.utils import is_primitive
|
|
from sheerkarete.alpha import AlphaMemory
|
|
from sheerkarete.beta import ReteNode, BetaMemory
|
|
from sheerkarete.bind_node import BindNode
|
|
from sheerkarete.common import WME, Match, V
|
|
from sheerkarete.conditions import Condition, NegatedCondition, NegatedConjunctiveConditions, FilterCondition, \
|
|
BindCondition
|
|
from sheerkarete.filter_node import FilterNode
|
|
from sheerkarete.join_node import JoinNode
|
|
from sheerkarete.ncc_node import NccNode, NccPartnerNode
|
|
from sheerkarete.negative_node import NegativeNode
|
|
from sheerkarete.pnode import PNode
|
|
|
|
if TYPE_CHECKING: # pragma: no cover
|
|
from typing import Dict
|
|
from typing import Tuple
|
|
from typing import List
|
|
from typing import Set
|
|
from typing import Hashable
|
|
|
|
# Name of the synthetic attribute stamped onto every object added to the
# network (see add_obj / remove_obj) so an object can be mapped back to its
# working-memory fact id.
FACT_ID = "##fact_id##"
|
|
|
|
|
|
class ReteNetwork:
    """
    A Rete Network to store all the facts and productions to compute matches.

    Alpha memories index facts by (identifier, attribute, value) keys; the
    beta network (join / negative / ncc / filter / bind nodes) combines them
    into tokens that end up as activations on production nodes (PNode).
    """
|
|
|
|
def __init__(self):

    # alpha memories indexed by their (identifier, attribute, value) key;
    # several memories may share a key, hence the list
    self.alpha_hash: Dict[Tuple[Hashable, Hashable, Hashable], List[AlphaMemory]] = {}
    self.beta_root = ReteNode()
    self.pnodes: List[PNode] = []  # list of all production nodes
    self.rules: Set[Rule] = set()  # set of all known rules
    self.working_memory: Set[WME] = set()

    # monotonically increasing counter used to mint fact ids ("f-00042")
    self.fact_counter: int = 0
    # fact id -> the original object added via add_obj
    self.facts: Dict[str, object] = {}

    self.attributes_by_id = {}  # keep track of requested conditions attributes, for a given id
    self.default_attributes = set()  # keep track of requested attributes, when the id is not given
|
|
|
|
@property
def matches(self) -> Generator[Match, None, None]:
    """Yield a Match for every current activation of every production node."""
    yield from (Match(node, activation)
                for node in self.pnodes
                for activation in node.activations)
|
|
|
|
def build_or_share_alpha_memory(self, condition):
    """
    Return the alpha memory for `condition`, creating it if necessary.

    A new memory is immediately primed with every already-known WME that
    passes the condition's test.

    :type condition: Condition
    :rtype: AlphaMemory
    """
    key = condition.get_key()

    # reuse an existing memory for the same condition, if any
    shared = next(
        (mem for mem in self.alpha_hash.get(key, ()) if mem.condition == condition),
        None)
    if shared is not None:
        return shared

    # otherwise create one and register it under its key
    memory = AlphaMemory(key, condition)
    self.alpha_hash.setdefault(key, []).append(memory)

    # replay the existing working memory into the fresh alpha memory
    for fact in self.working_memory:
        if condition.test(fact):
            memory.activation(fact)

    return memory
|
|
|
|
def build_or_share_beta_memory(self, parent):
    """
    Create or reuse a BetaMemory under `parent`.

    Reuse requires an *exact* BetaMemory — subclasses (e.g. NCC nodes)
    carry different semantics — hence the identity check on type() rather
    than isinstance().
    """
    for child in parent.children:
        # exact type match on purpose: don't include subclasses
        # (was `type(child) == BetaMemory`; `is` is the correct identity
        # comparison for types and avoids any __eq__ surprises)
        if type(child) is BetaMemory:
            return child

    node = BetaMemory(parent=parent)
    parent.children.append(node)
    # bring the new memory up to date with tokens its parent already matched
    self.update_new_node_with_matches_from_above(node)
    return node
|
|
|
|
def build_or_share_join_node(self, parent: BetaMemory, amem: AlphaMemory, condition: Condition) -> JoinNode:
    """
    Creates or reuse a JoinNode

    :param parent: parent beta memory
    :param amem: parent alpha memory
    :param condition: condition for the join
    :returns: the shared or newly created join node
    """

    # search for already created join node; all_children also contains
    # left-unlinked joins, so sharing works while a node is detached from
    # parent.children
    for child in parent.all_children:
        if type(child) == JoinNode and child.amem == amem and child.condition == condition:
            return child

    node = JoinNode(children=[], parent=parent, amem=amem, condition=condition)

    parent.children.append(node)
    parent.all_children.append(node)

    # the alpha memory keeps the node alive via reference counting
    # (see delete_node_and_any_unused_ancestors)
    amem.successors.append(node)
    amem.reference_count += 1

    node.update_nearest_ancestor_with_same_amem()

    # little optimisation. No need to bind if there is no wme in parent:
    # right-unlink (drop from amem.successors) while the beta side is empty,
    # otherwise left-unlink (drop from parent.children) while the alpha side
    # is empty.
    if not parent.items:
        amem.successors.remove(node)
    elif not amem.items:
        parent.children.remove(node)

    return node
|
|
|
|
def build_or_share_negative_node(self,
                                 parent: JoinNode,
                                 amem: AlphaMemory,
                                 condition: NegatedCondition) -> NegativeNode:
    """
    Create or reuse a NegativeNode implementing a single negated condition.

    :param parent: node the negative node hangs under
    :param amem: alpha memory feeding the negation
    :param condition: the negated condition
    :return: the shared or newly created negative node
    """
    # search for an already created negative node with the same inputs
    for child in parent.children:
        if isinstance(child, NegativeNode) and child.amem == amem and child.condition == condition:
            return child

    node = NegativeNode(parent=parent, amem=amem, condition=condition)
    parent.children.append(node)

    # the alpha memory keeps the node alive via reference counting
    amem.successors.append(node)
    amem.reference_count += 1

    node.update_nearest_ancestor_with_same_amem()
    self.update_new_node_with_matches_from_above(node)

    # little optimisation. No need to bind if there is no wme in parent:
    # right-unlink the node while it holds no tokens
    if not node.items:
        amem.successors.remove(node)

    return node
|
|
|
|
def build_or_share_ncc_nodes(self,
                             parent: JoinNode,
                             ncc: NegatedConjunctiveConditions,
                             earlier_conds: List[Condition]) -> NccNode:
    """
    Create or reuse the NccNode / NccPartnerNode pair implementing a
    negated conjunction (NCC).

    The negated sub-conditions get their own subnetwork; the bottom of that
    subnetwork feeds the partner node, which reports results back to the
    NccNode.

    :param parent: node the NCC pair hangs under
    :param ncc: the negated conjunctive conditions
    :param earlier_conds: conditions already built higher up the network
    :return: the shared or newly created NccNode
    """

    # build the subnetwork first: sharing is detected by comparing the
    # partner's parent to the subnetwork's bottom node
    bottom_of_subnetwork = self.build_or_share_network_for_conditions(parent, ncc, earlier_conds)
    for child in parent.children:
        if isinstance(child, NccNode) and child.partner.parent == bottom_of_subnetwork:
            return child

    ncc_partner = NccPartnerNode(parent=bottom_of_subnetwork)
    ncc_node = NccNode(partner=ncc_partner, children=[], parent=parent)
    ncc_partner.ncc_node = ncc_node
    # inserted at the front — presumably so the NccNode is activated before
    # its partner on updates (standard rete ordering); confirm
    parent.children.insert(0, ncc_node)
    bottom_of_subnetwork.children.append(ncc_partner)
    ncc_partner.number_of_conditions = ncc.number_of_conditions
    # NccNode must be updated before its partner so buffered results pair up
    self.update_new_node_with_matches_from_above(ncc_node)
    self.update_new_node_with_matches_from_above(ncc_partner)
    return ncc_node
|
|
|
|
def build_or_share_filter_node(self,
                               parent: ReteNode,
                               f: FilterCondition) -> FilterNode:
    """Create a FilterNode under `parent`, or reuse one wrapping the same function."""
    shared = next((child for child in parent.children
                   if isinstance(child, FilterNode) and child.func == f.func),
                  None)
    if shared is not None:
        return shared

    node = FilterNode([], parent, f.func, self)
    parent.children.append(node)
    return node
|
|
|
|
def build_or_share_bind_node(self,
                             parent: ReteNode,
                             b: BindCondition) -> BindNode:
    """Create a BindNode under `parent`, or reuse one with the same function and target variable."""
    shared = next((child for child in parent.children
                   if isinstance(child, BindNode)
                   and child.func == b.func
                   and child.bind == b.to),
                  None)
    if shared is not None:
        return shared

    node = BindNode([], parent, b.func, b.to, self)
    parent.children.append(node)
    return node
|
|
|
|
def build_or_share_p_node(self, parent: JoinNode, rule: Rule) -> Union[PNode, None]:
    """
    Create or reuse a production node

    :param parent: parent join node
    :param rule: rule that will be fired on activation
    :return: returns None if the PNode already exists
    """
    existing = next((child for child in parent.children
                     if isinstance(child, PNode)), None)
    if existing is not None:
        # attach the rule to the shared production node
        existing.rules.append(rule)
        rule.rete_p_nodes.append(existing)
        return None

    node = PNode(rule=rule, parent=parent)
    parent.children.append(node)
    self.update_new_node_with_matches_from_above(node)
    rule.rete_p_nodes.append(node)
    return node
|
|
|
|
def build_or_share_network_for_conditions(self, parent, conditions, earlier_conditions) -> ReteNode:
    """
    Build (or reuse) the chain of rete nodes implementing `conditions`
    under `parent`, and return the lowest node of the chain.

    :param parent: node to graft the chain onto (the beta root, or a join
        node when building an NCC subnetwork)
    :param conditions: the conditions of one disjunction, in order
    :param earlier_conditions: conditions already handled higher up.
        NOTE(review): this list is mutated in place (every processed
        condition is appended), so the caller's list grows as a side
        effect — confirm this aliasing is intended.
    """
    current_node = parent
    conds_higher_up = earlier_conditions

    # Explanation on vars_ids_mappings
    # conditions = [Condition(V("x"), "__name__", "fact_name"),
    #               Condition(V("x"), "attr1", "value1"),
    #               Condition(V("x"), "attr1", "value1")]
    # V(x) actually refers to a object named 'fact_name'
    # self.conditions_attributes_by_id must be updated accordingly
    vars_ids_mappings = {}

    for cond in conditions:
        # update requested attributes for a fact
        if isinstance(cond, Condition):

            # Manage list of requested attributes when using __name__ indirection
            if isinstance(cond.identifier, V) and cond.attribute == "__name__":
                vars_ids_mappings[cond.identifier] = cond.value

            # Manage list of requested attributes when bounding a new variable:
            # Condition(V("x"), "attr", V("y")) maps V("y") to "<name of x>.attr"
            if (cond.identifier in vars_ids_mappings and
                    isinstance(cond.attribute, str) and
                    isinstance(cond.value, V)):
                vars_ids_mappings[cond.value] = f"{vars_ids_mappings[cond.identifier]}.{cond.attribute}"

            # resolve the identifier: mapped variable -> its object name,
            # literal -> itself, unmapped variable -> None
            identifier = vars_ids_mappings[cond.identifier] if cond.identifier in vars_ids_mappings else \
                cond.identifier if not isinstance(cond.identifier, V) else \
                None
            if identifier:
                # a variable in attribute position means "any attribute"
                attr = "*" if isinstance(cond.attribute, V) else cond.attribute
                self.attributes_by_id.setdefault(identifier, []).append(attr)
            elif not isinstance(cond.attribute, V):
                self.default_attributes.add(cond.attribute)

        # create the alpha memory (if needed), beta memory and join node
        if isinstance(cond, Condition) and not isinstance(cond, NegatedCondition):
            am = self.build_or_share_alpha_memory(cond)
            current_node = self.build_or_share_beta_memory(current_node)
            current_node = self.build_or_share_join_node(current_node, am, cond)

        elif isinstance(cond, NegatedCondition):
            am = self.build_or_share_alpha_memory(cond)
            current_node = self.build_or_share_negative_node(current_node, am, cond)

        elif isinstance(cond, NegatedConjunctiveConditions):
            current_node = self.build_or_share_ncc_nodes(current_node, cond, conds_higher_up)

        elif isinstance(cond, FilterCondition):
            current_node = self.build_or_share_filter_node(current_node, cond)

        elif isinstance(cond, BindCondition):
            current_node = self.build_or_share_bind_node(current_node, cond)

        conds_higher_up.append(cond)

    return current_node
|
|
|
|
def get_rete_conditions(self, rule):
    """
    Gets the conditions from a rule.

    It's in fact the list of disjunctions.
    Not sure yet which component will hold this functionality.

    :raises NotImplementedError: if the rule does not expose
        get_rete_disjunctions()
    """
    if hasattr(rule, "get_rete_disjunctions"):
        return rule.get_rete_disjunctions()

    # was raised with an empty message, which made failures undiagnosable
    raise NotImplementedError(
        f"{type(rule).__name__} does not implement get_rete_disjunctions()")
|
|
|
|
def add_rule(self, rule: Rule):
    """
    Compile a rule into the network.

    Disabled, uncompiled or print-action rules are silently skipped.
    Each disjunction of the rule becomes one branch of the beta network
    ending in a PNode.

    :raises ValueError: if the rule has no id, or was already added
    """

    if rule.id is None:
        raise ValueError("Rule has no id, cannot add")

    # only enabled, compiled, non-print rules take part in matching
    if (not rule.metadata.is_enabled or
            not rule.metadata.is_compiled or
            rule.metadata.action_type == ACTION_TYPE_PRINT):
        return

    if rule.rete_net:
        raise ValueError("Rule is already added")

    rule.rete_net = self
    self.rules.add(rule)

    # one branch per disjunction; build_or_share_p_node returns None when
    # an existing PNode was reused (already present in self.pnodes)
    for full_condition in self.get_rete_conditions(rule):
        conditions = full_condition.conditions
        current_node = self.build_or_share_network_for_conditions(self.beta_root, conditions, [])
        p_node = self.build_or_share_p_node(current_node, rule)
        if p_node is not None:
            self.pnodes.append(p_node)
|
|
|
|
def remove_rule(self, rule: Rule):
    """
    Remove a rule and its production nodes from the network.

    No-op when the rule was never added. After removal the rule can be
    added to a network again.
    """
    if rule.rete_net is None:
        return

    # Remove production
    self.rules.remove(rule)

    for pnode in rule.rete_p_nodes:
        pnode.rules.remove(rule)
        # last rule on this production node: delete the node and any
        # ancestors that become unused
        if len(pnode.rules) == 0:
            self.delete_node_and_any_unused_ancestors(pnode)
            self.pnodes.remove(pnode)

    # reset the bookkeeping set by add_rule so the rule can be re-added.
    # (was `rule.p_nodes = []` — a stale attribute name that left
    # rule.rete_p_nodes populated and rule.rete_net pointing at this
    # network, making any later add_rule raise "already added")
    rule.rete_p_nodes = []
    rule.rete_net = None
|
|
|
|
def add_wme(self, wme: WME) -> None:
    """
    Insert a WME into working memory, activating every alpha memory whose
    key matches it. Inserting an already-known WME is a no-op.
    """
    if wme in self.working_memory:
        return

    # every alpha key that could index this WME: each of the three fields
    # may be matched either exactly or by the '*' wildcard
    candidate_keys = product((wme.identifier, '*'),
                             (wme.attribute, '*'),
                             (wme.value, '*'))

    for candidate in candidate_keys:
        for amem in reversed(self.alpha_hash.get(candidate, [])):
            amem.activation(wme)

    self.working_memory.add(wme)
|
|
|
|
def remove_wme(self, wme: WME) -> None:
    """
    Remove a WME from working memory and retract everything that depended
    on it: alpha memory entries, tokens built on it, and negative join
    results.

    Removal is keyed by equality — the *stored* WME equal to `wme` is the
    one retracted, so callers may pass an equal copy. Removing a WME that
    is not in working memory is a no-op (previously this crashed part-way
    through with the caller's unregistered instance).
    """
    stored = next((candidate for candidate in self.working_memory
                   if candidate == wme), None)
    if stored is None:
        return
    wme = stored

    # detach from alpha memories; when an alpha memory empties, its plain
    # join successors are left-unlinked from their parents
    for am in wme.amems:
        am.items.remove(wme)
        if not am.items:
            for node in am.successors:
                if isinstance(node, JoinNode) and not isinstance(node, NegativeNode):
                    node.parent.children.remove(node)

    # delete every token built on this WME; deletion cascades to
    # descendants (and removes the token from wme.tokens)
    while wme.tokens:
        t = wme.tokens[0]
        t.delete_token_and_descendents()

    # a negated condition may now pass: re-activate owners that just lost
    # their last join result
    for jr in wme.negative_join_results:
        jr.owner.join_results.remove(jr)
        if not jr.owner.join_results:
            if jr.owner.node and jr.owner.node.children is not None:
                for child in jr.owner.node.children:
                    child.left_activation(jr.owner, None, jr.owner.binding)

    self.working_memory.remove(wme)
|
|
|
|
def remove_wme_by_fact_id(self, identifier: str) -> None:
    """Remove every WME whose identifier is `identifier` itself or a dotted child of it."""
    child_prefix = identifier + "."
    doomed = [wme for wme in self.working_memory
              if wme.identifier == identifier
              or wme.identifier.startswith(child_prefix)]
    for wme in doomed:
        self.remove_wme(wme)
|
|
|
|
def add_obj(self, name, obj, fact_id=None, use_bag=False):
    """
    Adds a new object to the working memory.

    Only the attributes some rule condition actually requested (tracked in
    attributes_by_id / default_attributes) are turned into WMEs, unless
    `use_bag` is set, in which case every attribute from as_bag(obj) is
    added. Concept-valued attributes are added recursively under a dotted
    fact id ("<fact_id>.<attr>").

    :param name: symbolic name of the object — used for the attribute
        lookup and as the value of the "__name__" WME
    :param obj: the object to add
    :param fact_id: internal fact id; minted (and stamped on obj via
        FACT_ID) when None
    :param use_bag: add every attribute instead of only the requested ones
    :raises ValueError: if the object already carries a fact id
    """

    def inner_add_vme(name_, fact_id_, attr_, value_):
        # helper: emit the WME(s) for one attribute; recurses for nested
        # Concepts (except the "self" attribute, to avoid self-reference)
        if value_ is NotInit:
            # uninitialised values produce no WME at all
            pass
        elif attr_ != "self" and isinstance(value_, Concept):
            # nested concept: link parent to child via a dotted id, then
            # add the child object itself under that id
            new_name = f"{name_}.{attr_}"
            new_fact_id = f"{fact_id_}.{attr_}"
            self.add_wme(WME(fact_id_, attr_, new_fact_id))
            self.add_obj(new_name, value_, new_fact_id)
        else:
            self.add_wme(WME(fact_id_, attr_, value_))

    if fact_id is None:
        if hasattr(obj, FACT_ID):
            raise ValueError("Object already has an id, cannot add")

        # mint a fresh id, stamp it on the object and register the object
        fact_id = f"f-{self.fact_counter:05}"
        setattr(obj, FACT_ID, fact_id)
        self.facts[fact_id] = obj
        self.fact_counter += 1

    # attributes the rule conditions asked for; "*" means "all of them".
    # (when use_bag is set the *string* "*" is used, which iterates to the
    # single item "*" below)
    requested_attributes = "*" if use_bag else \
        self.attributes_by_id[name] if name in self.attributes_by_id else \
        self.default_attributes

    for attribute in requested_attributes:
        if attribute == "*":
            # expand the object into all of its attributes
            bag = as_bag(obj)
            for k, v in bag.items():
                inner_add_vme(name, fact_id, k, v)
        elif attribute == "__name__":
            self.add_wme(WME(fact_id, "__name__", name))
        elif attribute == "__is_concept__":
            self.add_wme(WME(fact_id, "__is_concept__", isinstance(obj, Concept)))
        else:
            try:
                value = getattr(obj, attribute)
                # normalise the sheerka singleton to a marker value.
                # NOTE(review): assumes the BuiltinConcepts.SHEERKA concept
                # and an Expando named "sheerka" denote the same entity —
                # confirm
                if (isinstance(value, Concept) and value.key == BuiltinConcepts.SHEERKA or
                        isinstance(value, Expando) and value.get_name() == "sheerka"):
                    value = "__sheerka__"
                if is_primitive(value):
                    self.add_wme(WME(fact_id, attribute, value))
                else:
                    inner_add_vme(name, fact_id, attribute, value)
            except AttributeError:
                # requested attribute missing on this object: skip it
                pass
|
|
|
|
def remove_obj(self, obj):
    """
    Remove a previously added object, all of its WMEs (including dotted
    children) and its fact-id stamp.

    :raises ValueError: if the object carries no fact id or is unknown
    """
    fact_id = getattr(obj, FACT_ID, None)
    if fact_id is None or fact_id not in self.facts:
        raise ValueError("Fact has no id or does not exist in network.")

    self.remove_wme_by_fact_id(fact_id)
    del self.facts[fact_id]
    delattr(obj, FACT_ID)
|
|
|
|
def update_new_node_with_matches_from_above(self, new_node: ReteNode) -> None:
    """
    Bring a freshly attached node up to date with everything its parent
    has already matched, so that network construction order does not
    matter.

    The replay strategy depends on the parent's type; for join-like
    parents the parent's children list is temporarily swapped for
    [new_node] so that re-activation reaches only the new node.
    """
    parent = new_node.parent
    if parent == self.beta_root:
        # the root matches once, unconditionally, with an empty binding
        new_node.left_activation(None, None, {})
    elif isinstance(parent, BetaMemory) and not isinstance(parent, (NccNode, NegativeNode)):
        # plain beta memory: replay each stored token
        for tok in parent.items:
            new_node.left_activation(token=tok)
    elif isinstance(parent, JoinNode) and not isinstance(parent, NegativeNode):
        # re-run the join against its alpha memory items, with only
        # new_node attached, then restore the real children list
        saved_list_of_children = parent.children
        parent.children = [new_node]
        for item in parent.amem.items:
            parent.right_activation(item)
        parent.children = saved_list_of_children
    elif isinstance(parent, NegativeNode):
        # only tokens with no join results pass a negation
        for token in parent.items:
            if not token.join_results:
                new_node.left_activation(token, None, token.binding)
    elif isinstance(parent, NccNode):
        # only tokens with no ncc results pass an NCC
        for token in parent.items:
            if not token.ncc_results:
                new_node.left_activation(token, None, token.binding)
    elif isinstance(parent, (BindNode, FilterNode)):
        # bind/filter nodes only forward: recurse one level up with only
        # new_node attached so the replay flows through the parent
        saved_list_of_children = parent.children
        parent.children = [new_node]
        self.update_new_node_with_matches_from_above(parent)
        parent.children = saved_list_of_children
|
|
|
|
def delete_alpha_memory(self, amem: AlphaMemory):
    """
    Drop one alpha memory from the alpha hash.

    alpha_hash maps a key to a *list* of memories (several conditions can
    share a key — see build_or_share_alpha_memory), so only this memory is
    removed; the bucket itself is deleted once it is empty. (Previously
    the whole bucket was deleted, discarding sibling memories that shared
    the key.)
    """
    bucket = self.alpha_hash.get(amem.key)
    if bucket is None:
        return
    remaining = [mem for mem in bucket if mem is not amem]
    if remaining:
        self.alpha_hash[amem.key] = remaining
    else:
        del self.alpha_hash[amem.key]
|
|
|
|
def delete_node_and_any_unused_ancestors(self, node: ReteNode):
    """
    Remove a node from the network, cascading upward to ancestors that
    become childless.

    Tokens held by the node (or, for an NCC partner, its buffered results)
    are deleted first; plain join nodes additionally release their alpha
    memory reference and may trigger its deletion.
    """
    # an NCC node owns its partner: delete the partner branch first
    if isinstance(node, NccNode):
        self.delete_node_and_any_unused_ancestors(node.partner)

    if isinstance(node, BetaMemory):
        for item in node.items:
            item.delete_token_and_descendents()

    if isinstance(node, NccPartnerNode):
        for item in node.new_result_buffer:
            item.delete_token_and_descendents()

    if isinstance(node, JoinNode) and not isinstance(node, NegativeNode):
        # a right-unlinked node is already absent from the successor list
        if not node.right_unlinked:
            node.amem.successors.remove(node)

        node.amem.reference_count -= 1

        # last user of the alpha memory: drop it as well
        if node.amem.reference_count == 0:
            self.delete_alpha_memory(node.amem)

        # a left-unlinked node is already absent from parent.children
        if not node.left_unlinked:
            node.parent.children.remove(node)

        node.parent.all_children.remove(node)

        # the parent beta memory is useless without any join below it
        if not node.parent.all_children:
            self.delete_node_and_any_unused_ancestors(node.parent)

    elif node.parent:
        node.parent.children.remove(node)
        if not node.parent.children:
            self.delete_node_and_any_unused_ancestors(node.parent)
|