from dataclasses import dataclass, field
from typing import Any, Literal

from common.utils import str_concept
from core.ExecutionContext import ExecutionContext
from core.concept import ConceptMetadata
from parsers.ParserInput import ParserInput
from parsers.parser_utils import UnexpectedEof, UnexpectedToken, get_text_from_tokens
from parsers.tokenizer import Token


@dataclass
class MetadataToken:
    """
    Represents a span of text that was recognized as a concept

    The start and end positions of the span are tracked
    """
    metadata: ConceptMetadata
    start: int
    end: int
    resolution_method: Literal["name", "key", "id"]
    parser: str

    def __repr__(self):
        return f"(MetadataToken metadata={str_concept(self.metadata, drop_name=True)}, " + \
               f"start={self.start}, end={self.end}, method={self.resolution_method}, origin={self.parser})"

    def __eq__(self, other):
        if not isinstance(other, MetadataToken):
            return False

        return self.metadata.id == other.metadata.id \
            and self.start == other.start \
            and self.end == other.end \
            and self.parser == other.parser

    def __hash__(self):
        return hash((self.metadata.id, self.start, self.end, self.parser))


@dataclass
class UnrecognizedToken:
    """
    Represents a span of text that was not recognized

    The start and end positions of the span are tracked
    """
    buffer: str
    start: int
    end: int


@dataclass
class StateResult:
    next_state: str | None
    forks: list | None = None


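# Illustrative note (hypothetical state names): a State.run() implementation
# signals a plain transition with StateResult("some state"), ends the current
# execution path with StateResult(None), and spawns additional paths by
# carrying forks as (next_state, state_context) pairs, e.g.
#
#     StateResult("read tokens", forks=[("concepts found", cloned_context)])
#
# StateMachine.run() turns every fork into a new ExecutionPath.

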
@dataclass
class ConceptToRecognize:
    """
    Holds information about the concept to recognize
    """
    metadata: ConceptMetadata
    expected_tokens: list
    resolution_method: Literal["name", "key", "id"]  # which attribute was used to resolve the concept


@dataclass
class StateMachineContext:
    context: ExecutionContext
    parser_input: ParserInput
    get_metadata_from_first_token: Any
    buffer: list[Token] = field(default_factory=list)
    buffer_start_pos: int = -1
    concept_to_recognize: ConceptToRecognize | None = None
    result: list = field(default_factory=list)
    errors: list = field(default_factory=list)

    def get_clones(self, concepts_to_recognize):
        """Create one independent copy of this context per concept to recognize."""
        return [StateMachineContext(self.context,
                                    self.parser_input.clone(),
                                    self.get_metadata_from_first_token,
                                    self.buffer.copy(),
                                    self.buffer_start_pos,
                                    concept,
                                    self.result.copy(),
                                    self.errors.copy())
                for concept in concepts_to_recognize]

    def to_debug(self):
        return {"pos": self.parser_input.pos,
                "token": self.parser_input.token,
                "buffer": [token.value for token in self.buffer],
                "concept": str_concept(self.concept_to_recognize.metadata) if self.concept_to_recognize else None,
                "result": self.result.copy()}


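# Presumed shape of the token-recognition workflow (the actual wiring is
# assembled by the caller; state names other than "eof" and "concepts found"
# are hypothetical):
#
#   start               -> prepare read tokens
#   prepare read tokens -> read tokens
#   read tokens         -> read tokens          (loops while buffering)
#                       -> "eof"                (input exhausted)
#                       -> forks to "concepts found" when candidate concepts match
#   read concept        -> next state           (on match, mismatch, or eof)
#   manage unrecognized -> next state           (flushes the buffer as an UnrecognizedToken)

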
class State:
    def __init__(self, name, next_states):
        self.name = name
        self.next_states = next_states

    def run(self, state_context: StateMachineContext) -> StateResult:
        raise NotImplementedError

    @staticmethod
    def get_forks(next_state, states_contexts: list[StateMachineContext]):
        """
        Create one fork item for every state context
        :param next_state: name of the state every fork should transition to
        :type next_state: str
        :param states_contexts: one context per path to fork
        :type states_contexts: list[StateMachineContext]
        :return: the (next_state, state_context) pairs describing the forks
        :rtype: list[tuple]
        """
        return [(next_state, state_context) for state_context in states_contexts]

    def __repr__(self):
        return f"(State '{self.name}' -> {self.next_states})"


class Start(State):
    def run(self, state_context) -> StateResult:
        # entry point of the workflow: simply hand over to the next state
        return StateResult(self.next_states[0])

    def __repr__(self):
        return f"(StartState '{self.name}' -> '{self.next_states[0]}')"


class PrepareReadTokens(State):
    def run(self, state_context: StateMachineContext) -> StateResult:
        state_context.buffer.clear()
        state_context.buffer_start_pos = state_context.parser_input.pos + 1
        return StateResult(self.next_states[0])


class ReadTokens(State):
    def run(self, state_context) -> StateResult:
        if not state_context.parser_input.next_token(False):
            return StateResult("eof")

        # try to get the possible concepts to recognize
        concepts = state_context.get_metadata_from_first_token(state_context.context,
                                                               state_context.parser_input.token)

        # fork one path per candidate concept
        forks = self.get_forks("concepts found", state_context.get_clones(concepts)) if concepts else None

        state_context.buffer.append(state_context.parser_input.token)
        return StateResult(self.name, forks)


class ManageUnrecognized(State):
    def run(self, state_context) -> StateResult:
        if state_context.buffer:
            buffer_as_str = get_text_from_tokens(state_context.buffer)
            if len(state_context.result) > 0 and isinstance(old := state_context.result[-1], UnrecognizedToken):
                # merge with the previous unrecognized token
                state_context.result[-1] = UnrecognizedToken(old.buffer + buffer_as_str,
                                                             old.start,
                                                             state_context.parser_input.pos - 1)
            else:
                state_context.result.append(UnrecognizedToken(buffer_as_str,
                                                              state_context.buffer_start_pos,
                                                              state_context.parser_input.pos - 1))

        return StateResult(self.next_states[0])


class ReadConcept(State):
    def run(self, state_context) -> StateResult:
        start = state_context.parser_input.pos

        for expected in state_context.concept_to_recognize.expected_tokens:
            if not state_context.parser_input.next_token(False):
                # eof before the concept is recognized
                state_context.errors.append(UnexpectedEof(expected, state_context.parser_input.token))
                state_context.concept_to_recognize = None
                return StateResult(self.next_states[0])

            token = state_context.parser_input.token
            if token.value != expected:
                # token mismatch
                state_context.errors.append(UnexpectedToken(token, expected))
                state_context.concept_to_recognize = None
                return StateResult(self.next_states[0])

        state_context.result.append(MetadataToken(state_context.concept_to_recognize.metadata,
                                                  start,
                                                  state_context.parser_input.pos,
                                                  state_context.concept_to_recognize.resolution_method,
                                                  "simple"))

        state_context.concept_to_recognize = None
        return StateResult(self.next_states[0])


class End(State):
    def run(self, state_context) -> StateResult:
        return StateResult(None)

    def __repr__(self):
        return f"(EndState '{self.name}')"


@dataclass
class ExecutionPathHistory:
    from_state: str
    execution_context_debug: dict
    to_state: str = ""
    forks: list[tuple] | None = None
    parents: list | None = None

    def clone(self, parent_path_id):
        parents = self.parents.copy() if self.parents else []
        parents.append(parent_path_id)
        return ExecutionPathHistory(self.from_state,
                                    self.execution_context_debug.copy(),
                                    self.to_state,
                                    self.forks.copy() if self.forks else None,
                                    parents)

    def __repr__(self):
        return "History(from '{0}', to '{1}', using {2}, forks={3}, parents={4})".format(
            self.from_state,
            self.to_state,
            self.execution_context_debug,
            len(self.forks) if self.forks else 0,
            self.parents)


@dataclass
class ExecutionPath:
    path_id: int
    execution_context: Any
    current_workflow: str
    current_state: str

    history: list[ExecutionPathHistory]
    ended: bool = False

    def clone(self, path_id, new_execution_context, new_workflow, new_state):
        return ExecutionPath(path_id,
                             new_execution_context,
                             new_workflow,
                             new_state,
                             [h.clone(self.path_id) for h in self.history],
                             self.ended)

    def __repr__(self):
        return f"(Path id={self.path_id}, workflow='{self.current_workflow}', state='{self.current_state}')"

    def get_audit_trail(self):
        return [h.from_state for h in self.history]


class StateMachine:

    def __init__(self, workflows):
        self.workflows = workflows
        self.paths = None
        self.last_path_id = -1

    def run(self, workflow_name: str, state_name: str, execution_context):
        """
        Run the given workflow, starting from the given state
        :param workflow_name: name of the workflow to run
        :type workflow_name: str
        :param state_name: name of the state to start from
        :type state_name: str
        :param execution_context: the context shared with every state
        :type execution_context: Any
        :return: nothing; the resulting paths are available in self.paths
        :rtype: None
        """
        self.last_path_id = -1  # reset the path ids
        self.paths = [ExecutionPath(self._get_new_path_id(),
                                    execution_context,
                                    workflow_name,
                                    state_name,
                                    [],
                                    False)]

        while True:
            to_review = [p for p in self.paths if not p.ended]
            if len(to_review) == 0:
                break

            for path in to_review:
                # add traceability
                history = ExecutionPathHistory(f"{path.current_workflow}:{path.current_state}",
                                               path.execution_context.to_debug())
                path.history.append(history)

                current_state = self.workflows[path.current_workflow][path.current_state]
                res = current_state.run(path.execution_context)

                if res.next_state is None:
                    path.ended = True
                    continue  # an ended path cannot fork

                path.current_workflow, path.current_state = self._compute_next_workflow_and_state(path.current_workflow,
                                                                                                  res.next_state)

                # update traceability
                history.to_state = f"{path.current_workflow}:{path.current_state}"

                # add forks
                if res.forks:
                    new_paths = []
                    for next_state, next_execution_context in res.forks:
                        next_workflow, next_state = self._compute_next_workflow_and_state(path.current_workflow,
                                                                                          next_state)
                        new_paths.append(path.clone(self._get_new_path_id(),
                                                    next_execution_context,
                                                    next_workflow,
                                                    next_state))

                    self.paths.extend(new_paths)
                    history.forks = [p.path_id for p in new_paths]

    def _get_new_path_id(self):
        self.last_path_id += 1
        return self.last_path_id

    @staticmethod
    def _compute_next_workflow_and_state(workflow, state):
        # a state name prefixed with "#" is a jump to the "start" state
        # of the workflow bearing that name
        if state.startswith("#"):
            return state, "start"
        else:
            return workflow, state


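# ---------------------------------------------------------------------------
# Minimal usage sketch, assuming a caller-supplied execution context that only
# needs the to_debug() hook (DemoContext below is a hypothetical stand-in for
# ExecutionContext). It wires two trivial workflows and shows the "#" prefix
# convention handled by _compute_next_workflow_and_state.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class DemoContext:
        def to_debug(self):
            return {"demo": True}

    workflows = {"main": {"start": Start("start", ["#sub"])},
                 "#sub": {"start": Start("start", ["end"]),
                          "end": End("end", [])}}

    machine = StateMachine(workflows)
    machine.run("main", "start", DemoContext())

    # single path, no forks: prints ['main:start', '#sub:start', '#sub:end']
    print(machine.paths[0].get_audit_trail())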