Added SheerkaComparisonManager
This commit is contained in:
@@ -0,0 +1,206 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from cache.Cache import Cache
|
||||
from cache.ListCache import ListCache
|
||||
from core.builtin_concepts import BuiltinConcepts
|
||||
from core.sheerka.services.sheerka_service import ServiceObj, BaseService
|
||||
|
||||
|
||||
@dataclass
|
||||
class ComparisonObj(ServiceObj):
|
||||
"""
|
||||
Order to store
|
||||
"""
|
||||
property: str # property to compare
|
||||
a: int # id of concept a
|
||||
b: int # id of concept b
|
||||
op: str # comparison operation
|
||||
context: str = "#" # context when the comparison is right
|
||||
|
||||
|
||||
class SheerkaComparisonManager(BaseService):
|
||||
"""
|
||||
Manage partitioning of concepts
|
||||
"""
|
||||
NAME = "ComparisonManager"
|
||||
COMPARISON_ENTRY = "Comparison"
|
||||
RESOLVED_COMPARISON_ENTRY = "Resolved_Comparison"
|
||||
|
||||
def __init__(self, sheerka):
|
||||
super().__init__(sheerka)
|
||||
|
||||
@staticmethod
|
||||
def _compute_key(prop_name, comparison_context):
|
||||
return f"{prop_name}|{comparison_context}"
|
||||
|
||||
@staticmethod
|
||||
def _compute_weights(comparison_objs):
|
||||
"""
|
||||
For every element in greater_than_s, give it a weight
|
||||
if weight(a) > weight(b) it means that a > b
|
||||
:param comparison_objs: list of greater than objects
|
||||
:return:
|
||||
"""
|
||||
|
||||
values = {}
|
||||
for comparison_obj in comparison_objs:
|
||||
values[comparison_obj.a] = 1
|
||||
values[comparison_obj.b] = 1
|
||||
|
||||
for _ in range(len(comparison_objs)):
|
||||
for comparison_obj in comparison_objs:
|
||||
if comparison_obj.op == ">":
|
||||
values[comparison_obj.a] = values[comparison_obj.b] + 1
|
||||
else:
|
||||
values[comparison_obj.b] = values[comparison_obj.a] + 1
|
||||
|
||||
return values
|
||||
|
||||
@staticmethod
|
||||
def _get_partition(weighted_concepts):
|
||||
|
||||
res = {}
|
||||
for k, v in weighted_concepts.items():
|
||||
res.setdefault(v, []).append(k)
|
||||
return res
|
||||
|
||||
def _inner_add_comparison(self, comparison_obj):
|
||||
key = self._compute_key(comparison_obj.property, comparison_obj.context)
|
||||
previous = self.sheerka.cache_manager.get(self.COMPARISON_ENTRY, key)
|
||||
|
||||
new = previous.copy() if previous else []
|
||||
new.append(comparison_obj)
|
||||
|
||||
cycles = self.detect_cycles(new)
|
||||
if cycles:
|
||||
concepts_in_cycle = [self.sheerka.get_by_id(c) for c in cycles]
|
||||
chicken_an_egg = self.sheerka.new(BuiltinConcepts.CHICKEN_AND_EGG, body=concepts_in_cycle)
|
||||
return self.sheerka.ret(self.NAME, False, chicken_an_egg)
|
||||
|
||||
self.sheerka.cache_manager.put(self.RESOLVED_COMPARISON_ENTRY, key, self._compute_weights(new))
|
||||
self.sheerka.cache_manager.put(self.COMPARISON_ENTRY, key, comparison_obj)
|
||||
|
||||
return self.sheerka.ret(self.NAME, True, self.sheerka.new(BuiltinConcepts.SUCCESS))
|
||||
|
||||
def initialize(self):
|
||||
cache = ListCache(default=lambda k: self.sheerka.sdp.get(self.COMPARISON_ENTRY, k))
|
||||
self.sheerka.cache_manager.register_cache(self.COMPARISON_ENTRY, cache, True, True)
|
||||
|
||||
cache = Cache()
|
||||
self.sheerka.cache_manager.register_cache(self.RESOLVED_COMPARISON_ENTRY, cache, persist=False)
|
||||
|
||||
self.sheerka.bind_service_method(self, SheerkaComparisonManager.is_greater_than)
|
||||
self.sheerka.bind_service_method(self, SheerkaComparisonManager.is_less_than)
|
||||
self.sheerka.bind_service_method(self, SheerkaComparisonManager.get_partition)
|
||||
self.sheerka.bind_service_method(self, SheerkaComparisonManager.get_concepts_weights)
|
||||
|
||||
def is_greater_than(self, context, prop_name, concept_a, concept_b, comparison_context="#"):
|
||||
"""
|
||||
Records that the property of concept a is greater than concept b's one
|
||||
:param context:
|
||||
:param prop_name:
|
||||
:param concept_a:
|
||||
:param concept_b:
|
||||
:param comparison_context:
|
||||
:return:
|
||||
"""
|
||||
context.log(f"Setting concept {concept_a} is greater than {concept_b}", who=self.NAME)
|
||||
|
||||
event_digest = context.event.get_digest()
|
||||
comparison_obj = ComparisonObj(event_digest, prop_name, concept_a.id, concept_b.id, ">", comparison_context)
|
||||
return self._inner_add_comparison(comparison_obj)
|
||||
|
||||
def is_less_than(self, context, prop_name, concept_a, concept_b, comparison_context="#"):
|
||||
"""
|
||||
Records that the property of concept a is lesser than concept b's one
|
||||
:param context:
|
||||
:param prop_name:
|
||||
:param concept_a:
|
||||
:param concept_b:
|
||||
:param comparison_context:
|
||||
:return:
|
||||
"""
|
||||
context.log(f"Setting concept {concept_a} is less than {concept_b}", who=self.NAME)
|
||||
|
||||
event_digest = context.event.get_digest()
|
||||
comparison_obj = ComparisonObj(event_digest, prop_name, concept_a.id, concept_b.id, "<", comparison_context)
|
||||
return self._inner_add_comparison(comparison_obj)
|
||||
|
||||
def get_partition(self, prop_name, comparison_context="#"):
|
||||
weighted_concept = self.get_concepts_weights(prop_name, comparison_context)
|
||||
|
||||
return self._get_partition(weighted_concept)
|
||||
|
||||
def get_concepts_weights(self, prop_name, comparison_context="#"):
|
||||
weighted_concept = self.sheerka.cache_manager.get(
|
||||
self.RESOLVED_COMPARISON_ENTRY,
|
||||
self._compute_key(prop_name, comparison_context))
|
||||
|
||||
if weighted_concept is None:
|
||||
key = self._compute_key(prop_name, comparison_context)
|
||||
entries = self.sheerka.cache_manager.get(self.COMPARISON_ENTRY, key)
|
||||
|
||||
if entries is None:
|
||||
return {}
|
||||
else:
|
||||
weighted_concept = self._compute_weights(entries)
|
||||
self.sheerka.cache_manager.put(self.RESOLVED_COMPARISON_ENTRY, key, weighted_concept)
|
||||
|
||||
return weighted_concept
|
||||
|
||||
@staticmethod
|
||||
def detect_cycles(comparison_objs):
|
||||
"""
|
||||
# Thanks to Divyanshu Mehta for contributing this code
|
||||
# https://www.geeksforgeeks.org/detect-cycle-in-a-graph/?ref=lbp
|
||||
:param comparison_objs:
|
||||
:return:
|
||||
"""
|
||||
latest = comparison_objs[-1]
|
||||
if latest.op == "=":
|
||||
return None
|
||||
|
||||
def get_graph_and_vertices():
|
||||
_graph = {}
|
||||
_vertices = set()
|
||||
for obj in comparison_objs:
|
||||
if obj.op == "=":
|
||||
continue
|
||||
|
||||
_vertices.add(obj.a)
|
||||
_vertices.add(obj.b)
|
||||
if obj.op == ">":
|
||||
_graph.setdefault(obj.a, []).append(obj.b)
|
||||
else:
|
||||
_graph.setdefault(obj.b, []).append(obj.a)
|
||||
return _graph, _vertices
|
||||
|
||||
def is_cyclic(v):
|
||||
# Mark current node as visited and
|
||||
# adds to recursion stack
|
||||
visited[v] = True
|
||||
rec_stack[v] = True
|
||||
|
||||
# Recur for all neighbours
|
||||
# if any neighbour is visited and in
|
||||
# recStack then graph is cyclic
|
||||
if v in graph:
|
||||
for neighbour in graph[v]:
|
||||
if not visited[neighbour]:
|
||||
if is_cyclic(neighbour):
|
||||
return True
|
||||
elif rec_stack[neighbour]:
|
||||
return True
|
||||
|
||||
# The node needs to be poped from
|
||||
# recursion stack before function ends
|
||||
rec_stack[v] = False
|
||||
return False
|
||||
|
||||
graph, vertices = get_graph_and_vertices()
|
||||
visited = {k: False for k in vertices}
|
||||
rec_stack = {k: False for k in vertices}
|
||||
|
||||
if is_cyclic(latest.a): # only need to check from the latest add, since the graph was not cyclic before
|
||||
return [k for k, v in rec_stack.items() if v]
|
||||
return None
|
||||
Reference in New Issue
Block a user