# Source: Sheerka/src/parsers/state_machine.py (Python, 333 lines, 12 KiB)
from dataclasses import dataclass, field
from typing import Any, Literal
from common.utils import str_concept
from core.ExecutionContext import ExecutionContext
from core.concept import ConceptMetadata
from parsers.ParserInput import ParserInput
from parsers.parser_utils import UnexpectedEof, UnexpectedToken, get_text_from_tokens
from parsers.tokenizer import Token
@dataclass
class MetadataToken:
    """
    A span of input text that was resolved to a known concept.

    Records the concept's metadata together with the start/end token
    positions, how the concept was resolved, and which parser produced it.
    """
    metadata: ConceptMetadata
    start: int
    end: int
    resolution_method: Literal["name", "key", "id"]  # attribute used to resolve the concept
    parser: str  # name of the parser that emitted this token
    def __repr__(self):
        concept = str_concept(self.metadata, drop_name=True)
        return (f"(MetadataToken metadata={concept}, "
                f"start={self.start}, end={self.end}, "
                f"method={self.resolution_method}, origin={self.parser})")
    def __eq__(self, other):
        # resolution_method is left out of equality, matching the key
        # used by __hash__ below.
        if not isinstance(other, MetadataToken):
            return False
        return (self.metadata.id, self.start, self.end, self.parser) == \
               (other.metadata.id, other.start, other.end, other.parser)
    def __hash__(self):
        return hash((self.metadata.id, self.start, self.end, self.parser))
@dataclass
class UnrecognizedToken:
    """
    Class that represents a text that is not recognized
    We keep track of the start and the end position
    """
    # Raw text of the unrecognized run of tokens.
    buffer: str
    # Position of the first token of the run in the parser input.
    start: int
    # End position of the run (set from parser pos - 1 by ManageUnrecognized).
    end: int
@dataclass
class StateResult:
    """Outcome of a State.run() call."""
    # Name of the state to transition to; None means the path has ended.
    next_state: str | None
    # Optional list of (next_state, state_context) pairs (see State.get_forks);
    # None when the state does not fork.
    forks: list | None = None
@dataclass
class ConceptToRecognize:
    """
    Holds information about the concept to recognize
    """
    # Metadata of the candidate concept.
    metadata: ConceptMetadata
    # Remaining token values that must still be matched
    # (compared against Token.value in ReadConcept).
    expected_tokens: list
    resolution_method: Literal["name", "key", "id"] # which attribute was used to resolve the concept
@dataclass
class StateMachineContext:
    """
    Mutable context threaded through the parsing state machine.

    Tracks the parser input, the token buffer being accumulated, the
    concept currently being matched, and the results/errors so far.
    """
    context: ExecutionContext
    parser_input: ParserInput
    get_metadata_from_first_token: Any  # callable resolving candidate concepts from a token
    buffer: list[Token] = field(default_factory=list)
    buffer_start_pos: int = -1
    concept_to_recognize: ConceptToRecognize | None = None
    result: list = field(default_factory=list)
    errors: list = field(default_factory=list)
    def get_clones(self, concepts_to_recognize):
        """Return one independent copy of this context per candidate concept."""
        clones = []
        for candidate in concepts_to_recognize:
            clones.append(StateMachineContext(self.context,
                                              self.parser_input.clone(),
                                              self.get_metadata_from_first_token,
                                              list(self.buffer),
                                              self.buffer_start_pos,
                                              candidate,
                                              list(self.result),
                                              list(self.errors)))
        return clones
    def to_debug(self):
        """Snapshot of the context state for traceability/debug output."""
        candidate = self.concept_to_recognize
        return {
            "pos": self.parser_input.pos,
            "token": self.parser_input.token,
            "buffer": [tok.value for tok in self.buffer],
            "concept": str_concept(candidate.metadata) if candidate else None,
            "result": list(self.result),
        }
class State:
    """Base class for all state-machine states."""
    def __init__(self, name, next_states):
        self.name = name
        self.next_states = next_states
    def run(self, state_context: StateMachineContext) -> StateResult:
        """Execute the state; subclasses override this."""
        pass
    @staticmethod
    def get_forks(next_state, states_contexts: list[StateMachineContext]):
        """
        Create one fork item for every state context.

        :param next_state: state name the forked paths should move to
        :param states_contexts: contexts to pair with that state
        :return: list of (next_state, context) pairs
        """
        forks = []
        for ctx in states_contexts:
            forks.append((next_state, ctx))
        return forks
    def __repr__(self):
        return f"(State '{self.name}' -> {self.next_states})"
class Start(State):
    """Initial state of a workflow; simply hands over to its successor."""
    def run(self, state_context) -> StateResult:
        # Nothing to compute at the start: move to the single next state.
        successor = self.next_states[0]
        return StateResult(successor)
    def __repr__(self):
        return "(StartState '{0}' -> '{1}')".format(self.name, self.next_states[0])
class PrepareReadTokens(State):
    """Reset the token buffer before a new read cycle begins."""
    def run(self, state_context: StateMachineContext) -> StateResult:
        # Remember where the next buffered run will start (the position
        # after the current one), then drop any previously buffered tokens.
        state_context.buffer_start_pos = state_context.parser_input.pos + 1
        state_context.buffer.clear()
        return StateResult(self.next_states[0])
class ReadTokens(State):
    """Consume one token, buffering it and forking when it may start a concept."""
    def run(self, state_context) -> StateResult:
        parser_input = state_context.parser_input
        if not parser_input.next_token(False):
            # No more tokens: leave the read loop.
            return StateResult("eof")
        token = parser_input.token
        # Candidate concepts resolvable from the current token each get a
        # forked context that will try to recognize them. The clones are
        # taken BEFORE the token is buffered.
        candidates = state_context.get_metadata_from_first_token(state_context.context,
                                                                 token)
        forks = None
        if candidates:
            forks = self.get_forks("concepts found", state_context.get_clones(candidates))
        state_context.buffer.append(token)
        # Stay in this state to keep reading tokens.
        return StateResult(self.name, forks)
class ManageUnrecognized(State):
    """Flush the buffered tokens into the result as unrecognized text."""
    def run(self, state_context) -> StateResult:
        if state_context.buffer:
            text = get_text_from_tokens(state_context.buffer)
            end = state_context.parser_input.pos - 1
            result = state_context.result
            previous = result[-1] if result else None
            if isinstance(previous, UnrecognizedToken):
                # Extend the trailing unrecognized run instead of
                # appending a second adjacent one.
                result[-1] = UnrecognizedToken(previous.buffer + text,
                                               previous.start,
                                               end)
            else:
                result.append(UnrecognizedToken(text,
                                                state_context.buffer_start_pos,
                                                end))
        return StateResult(self.next_states[0])
class ReadConcept(State):
    """Try to match the expected tokens of the forked concept candidate."""
    def run(self, state_context) -> StateResult:
        parser_input = state_context.parser_input
        candidate = state_context.concept_to_recognize
        start = parser_input.pos
        next_state = self.next_states[0]
        for expected in candidate.expected_tokens:
            if not parser_input.next_token(False):
                # Input ended before the whole concept was seen.
                state_context.errors.append(UnexpectedEof(expected, parser_input.token))
                state_context.concept_to_recognize = None
                return StateResult(next_state)
            actual = parser_input.token
            if actual.value != expected:
                # The candidate does not match at this position.
                state_context.errors.append(UnexpectedToken(actual, expected))
                state_context.concept_to_recognize = None
                return StateResult(next_state)
        # Every expected token matched: record the recognized concept.
        state_context.result.append(MetadataToken(candidate.metadata,
                                                  start,
                                                  parser_input.pos,
                                                  candidate.resolution_method,
                                                  "simple"))
        state_context.concept_to_recognize = None
        return StateResult(next_state)
class End(State):
    """Terminal state: a None next state marks the path as ended."""
    def run(self, state_context) -> StateResult:
        return StateResult(None)
    def __repr__(self):
        return "(EndState '{0}')".format(self.name)
@dataclass
class ExecutionPathHistory:
from_state: str
execution_context_debug: dict
to_state: str = ""
forks: list[tuple] = None
parents: list = None
def clone(self, parent_path_id):
parents = self.parents.copy() if self.parents else []
parents.append(parent_path_id)
return ExecutionPathHistory(self.from_state,
self.execution_context_debug.copy(),
self.to_state,
self.forks.copy() if self.forks else None,
parents)
def __repr__(self):
return "History(from '{0}', to '{1}', using {2}, forks={3}, parents={4}".format(
self.from_state,
self.to_state,
self.execution_context_debug,
len(self.forks) if self.forks else 0,
self.parents)
@dataclass
class ExecutionPath:
    """One execution path explored by the state machine."""
    path_id: int
    execution_context: Any
    current_workflow: str
    current_state: str
    history: list[ExecutionPathHistory]
    ended: bool = False
    def clone(self, path_id, new_execution_path, new_workflow, new_state):
        """Fork this path: copy its history and point it at the given state."""
        cloned_history = [entry.clone(self.path_id) for entry in self.history]
        return ExecutionPath(path_id,
                             new_execution_path,
                             new_workflow,
                             new_state,
                             cloned_history,
                             self.ended)
    def __repr__(self):
        return f"(Path id={self.path_id}, workflow='{self.current_workflow}', state='{self.current_state}')"
    def get_audit_trail(self):
        """Return the from-state of every history entry, in order."""
        trail = []
        for entry in self.history:
            trail.append(entry.from_state)
        return trail
class StateMachine:
    """
    Workflow-driven state machine supporting forked execution paths.

    ``workflows`` maps workflow names to dicts of {state name: State}.
    Each run() explores one or more ExecutionPath objects; a state's
    result may carry forks, which spawn new paths explored in the same
    loop until every path has ended.
    """
    def __init__(self, workflows):
        # Mapping: workflow name -> {state name: State instance}.
        self.workflows = workflows
        # All execution paths (alive and ended) of the last run.
        self.paths = None
        # Counter used to hand out unique, increasing path ids.
        self.last_path_id = -1
    def run(self, workflow_name: str, state_name: str, execution_context):
        """
        Run the workflow from the state given in parameter.

        Repeatedly advances every non-ended path one state per pass,
        recording an ExecutionPathHistory entry for each step and
        appending freshly forked paths to self.paths.
        :param workflow_name: workflow to start in
        :type workflow_name: str
        :param state_name: state to start from
        :type state_name: str
        :param execution_context: context object handed to each state's run()
        :type execution_context: Any
        :return: None; the explored paths are left on self.paths
        :rtype: None
        """
        self.last_path_id = -1 # reset the path ids
        self.paths = [ExecutionPath(self._get_new_path_id(),
                                    execution_context,
                                    workflow_name,
                                    state_name,
                                    [],
                                    False)]
        while True:
            # Paths appended during the for-loop below are picked up on
            # the next pass of this while-loop.
            to_review = [p for p in self.paths if not p.ended]
            if len(to_review) == 0:
                break
            for path in to_review:
                # add traceability
                history = ExecutionPathHistory(f"{path.current_workflow}:{path.current_state}",
                                               path.execution_context.to_debug())
                path.history.append(history)
                current_state = self.workflows[path.current_workflow][path.current_state]
                res = current_state.run(path.execution_context)
                if res.next_state is None:
                    path.ended = True
                    continue # not possible to fork !
                path.current_workflow, path.current_state = self._compute_next_workflow_and_state(path.current_workflow,
                                                                                                  res.next_state)
                # update traceability
                history.to_state = f"{path.current_workflow}:{path.current_state}"
                # add forks
                if res.forks:
                    new_paths = []
                    for next_state, next_execution_context in res.forks:
                        # NOTE(review): fork targets are resolved against the
                        # workflow AFTER the main transition above — confirm
                        # this is intended when res.next_state switched workflows.
                        next_workflow, next_state = self._compute_next_workflow_and_state(path.current_workflow,
                                                                                          next_state)
                        new_paths.append(path.clone(self._get_new_path_id(),
                                                    next_execution_context,
                                                    next_workflow,
                                                    next_state))
                    self.paths.extend(new_paths)
                    # Record the ids of the paths spawned at this step.
                    history.forks = [p.path_id for p in new_paths]
    def _get_new_path_id(self):
        """Return the next unique path id (monotonically increasing)."""
        self.last_path_id += 1
        return self.last_path_id
    @staticmethod
    def _compute_next_workflow_and_state(workflow, state):
        """
        Resolve a state reference: a name starting with '#' is treated as
        a workflow reference and jumps to that workflow's 'start' state;
        any other name stays inside the current workflow.
        """
        if state.startswith("#"):
            return state, "start"
        else:
            return workflow, state