Source code for mlrun.serving.states

# Copyright 2023 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ["TaskStep", "RouterStep", "RootFlowStep", "ErrorStep"]

import os
import pathlib
import traceback
from copy import copy, deepcopy
from inspect import getfullargspec, signature
from typing import Any, Union

import mlrun

from ..config import config
from ..datastore import get_stream_pusher
from ..datastore.utils import (
    get_kafka_brokers_from_dict,
    parse_kafka_url,
)
from ..errors import MLRunInvalidArgumentError, err_to_str
from ..model import ModelObj, ObjectDict
from ..platforms.iguazio import parse_path
from ..utils import get_class, get_function, is_explicit_ack_supported
from .utils import StepToDict, _extract_input_data, _update_result_body

callable_prefix = "_"
path_splitter = "/"
previous_step = "$prev"
queue_class_names = [">>", "$queue"]


class GraphError(Exception):
    """error in graph topology or configuration"""

    pass


class StepKinds:
    router = "router"
    task = "task"
    flow = "flow"
    queue = "queue"
    choice = "choice"
    root = "root"
    error_step = "error_step"


_task_step_fields = [
    "kind",
    "class_name",
    "class_args",
    "handler",
    "skip_context",
    "after",
    "function",
    "comment",
    "shape",
    "full_event",
    "on_error",
    "responder",
    "input_path",
    "result_path",
]


def new_model_endpoint(class_name, model_path, handler=None, **class_args):
    class_args = deepcopy(class_args)
    class_args["model_path"] = model_path
    return TaskStep(class_name, class_args, handler=handler)


def new_remote_endpoint(url, **class_args):
    class_args = deepcopy(class_args)
    class_args["url"] = url
    return TaskStep("$remote", class_args)
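

# Illustrative sketch (not part of the original module): how the helper constructors above
# can be used. The class name and paths below are hypothetical placeholders; both helpers
# simply return TaskStep instances with the extra arguments folded into class_args.
def _example_endpoint_steps():
    model_step = new_model_endpoint("ClassifierModel", "store://models/my-proj/my-model:latest")
    remote_step = new_remote_endpoint("https://my-endpoint.example.com/predict")
    return model_step, remote_step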


class BaseStep(ModelObj):
    kind = "BaseStep"
    default_shape = "ellipse"
    _dict_fields = ["kind", "comment", "after", "on_error"]

    def __init__(self, name: str = None, after: list = None, shape: str = None):
        self.name = name
        self._parent = None
        self.comment = None
        self.context = None
        self.after = after or []
        self._next = None
        self.shape = shape
        self.on_error = None
        self._on_error_handler = None

    def get_shape(self):
        """graphviz shape"""
        return self.shape or self.default_shape

    def set_parent(self, parent):
        """set/link the step parent (flow/router)"""
        self._parent = parent

    @property
    def next(self):
        return self._next

    @property
    def parent(self):
        """step parent (flow/router)"""
        return self._parent

    def set_next(self, key: str):
        """set/insert the key as next after this step, optionally remove other keys"""
        if not self.next:
            self._next = [key]
        elif key not in self.next:
            self._next.append(key)
        return self

    def after_step(self, *after, append=True):
        """specify the previous step names"""
        # add new steps to the after list
        if not append:
            self.after = []
        for name in after:
            # if it's a step/task class (vs a str) extract its name
            name = name if isinstance(name, str) else name.name
            if name not in self.after:
                self.after.append(name)
        return self

    def error_handler(
        self,
        name: str = None,
        class_name=None,
        handler=None,
        before=None,
        function=None,
        full_event: bool = None,
        input_path: str = None,
        result_path: str = None,
        **class_args,
    ):
        """set an error handler on a step or on the entire graph (executed on failure/raise)

        When the error_handler is set on the graph object, the graph completes after the error
        handler execution.

        example:
        in the example below, an 'error_catcher' step is set as the error_handler of the 'raise' step:
        in case of an error/raise in the 'raise' step, handle_error will run, and after that the
        'echo' step will run.

            graph = function.set_topology('flow', engine='async')
            graph.to(name='raise', handler='raising_step')\
                .error_handler(name='error_catcher', handler='handle_error', full_event=True, before='echo')
            graph.add_step(name="echo", handler='echo', after="raise").respond()

        :param name:        unique name (and path) for the error handler step, default is the class name
        :param class_name:  class name or step object to build the step from
                            the error handler step is derived from a task step (i.e. no router/queue functionality)
        :param handler:     class/function handler to invoke on run/event
        :param before:      string or list of next step(s) names that will run after this step.
                            the `before` param must not specify upstream steps, as that would cause a loop.
                            if `before` is not specified, the graph will complete after the error handler execution.
        :param function:    function this step should run in
        :param full_event:  this step accepts the full event (not just the body)
        :param input_path:  selects the key/path in the event to use as input to the step
                            this requires that the event body behaves like a dict, for example:
                            event: {"data": {"a": 5, "b": 7}}, input_path="data.b" means the step will
                            receive 7 as input
        :param result_path: selects the key/path in the event to write the results to
                            this requires that the event body behaves like a dict, for example:
                            event: {"x": 5}, result_path="y" means the output of the step will be written
                            to event["y"], resulting in {"x": 5, "y": <result>}
        :param class_args:  class init arguments
        """
        if not (class_name or handler):
            raise MLRunInvalidArgumentError("class_name or handler must be provided")
        if isinstance(self, RootFlowStep) and before:
            raise MLRunInvalidArgumentError(
                "`before` arg can't be specified for graph error handler"
            )

        name = get_name(name, class_name)
        step = ErrorStep(
            class_name,
            class_args,
            handler,
            name=name,
            function=function,
            full_event=full_event,
            input_path=input_path,
            result_path=result_path,
        )
        self.on_error = name
        before = [before] if isinstance(before, str) else before
        step.before = before or []
        step.base_step = self.name
        if hasattr(self, "_parent") and self._parent:
            # when self is a step
            step = self._parent._steps.update(name, step)
            step.set_parent(self._parent)
        else:
            # when self is the graph
            step = self._steps.update(name, step)
            step.set_parent(self)
        return self

    def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs):
        """init the step class"""
        self.context = context

    def _is_local_function(self, context):
        return True

    def get_children(self):
        """get child steps (for router/flow)"""
        return []

    def __iter__(self):
        yield from []

    @property
    def fullname(self):
        """full path/name (including parents)"""
        name = self.name or ""
        if self._parent and self._parent.fullname:
            name = path_splitter.join([self._parent.fullname, name])
        return name.replace(":", "_")  # replace for graphviz escaping

    def _post_init(self, mode="sync"):
        pass

    def _set_error_handler(self):
        """init/link the error handler for this step"""
        if self.on_error:
            error_step = self.context.root.path_to_step(self.on_error)
            self._on_error_handler = error_step.run

    def _log_error(self, event, err, **kwargs):
        """log on failure (for sync mode)"""
        error_message = err_to_str(err)
        self.context.logger.error(
            f"step {self.name} got error {error_message} when processing an event:\n {event.body}"
        )
        error_trace = traceback.format_exc()
        self.context.logger.error(error_trace)
        self.context.push_error(
            event, f"{error_message}\n{error_trace}", source=self.fullname, **kwargs
        )

    def _call_error_handler(self, event, err, **kwargs):
        """call the error handler if it exists"""
        if not event.error:
            event.error = {}
        event.error[self.name] = err_to_str(err)
        event.origin_state = self.fullname
        return self._on_error_handler(event)

    def path_to_step(self, path: str):
        """return a step object from a step relative path/fullname"""
        path = path or ""
        tree = path.split(path_splitter)
        next_level = self
        for step in tree:
            if step not in next_level:
                raise GraphError(
                    f"step {step} doesn't exist in the graph under {next_level.fullname}"
                )
            next_level = next_level[step]
        return next_level

    def to(
        self,
        class_name: Union[str, StepToDict] = None,
        name: str = None,
        handler: str = None,
        graph_shape: str = None,
        function: str = None,
        full_event: bool = None,
        input_path: str = None,
        result_path: str = None,
        **class_args,
    ):
        """add a step right after this step and return the new step

        example:
        a 4-step pipeline ending with a stream:

            graph.to('URLDownloader')\
                 .to('ToParagraphs')\
                 .to(name='to_json', handler='json.dumps')\
                 .to('>>', 'to_v3io', path=stream_path)

        :param class_name:  class name or step object to build the step from
                            for router steps the class name should start with '*'
                            for queue/stream steps the class should be '>>' or '$queue'
        :param name:        unique name (and path) for the child step, default is the class name
        :param handler:     class/function handler to invoke on run/event
        :param graph_shape: graphviz shape name
        :param function:    function this step should run in
        :param full_event:  this step accepts the full event (not just the body)
        :param input_path:  selects the key/path in the event to use as input to the step
                            this requires that the event body behaves like a dict, for example:
                            event: {"data": {"a": 5, "b": 7}}, input_path="data.b" means the step will
                            receive 7 as input
        :param result_path: selects the key/path in the event to write the results to
                            this requires that the event body behaves like a dict, for example:
                            event: {"x": 5}, result_path="y" means the output of the step will be written
                            to event["y"], resulting in {"x": 5, "y": <result>}
        :param class_args:  class init arguments
        """
        if hasattr(self, "steps"):
            parent = self
        elif self._parent:
            parent = self._parent
        else:
            raise GraphError(
                f"step {self.name} parent is not set or it's not part of a graph"
            )

        name, step = params_to_step(
            class_name,
            name,
            handler,
            graph_shape=graph_shape,
            function=function,
            full_event=full_event,
            input_path=input_path,
            result_path=result_path,
            class_args=class_args,
        )
        step = parent._steps.update(name, step)
        step.set_parent(parent)
        if not hasattr(self, "steps"):
            # check that it's not the root, todo: in the future we may have nested flows
            step.after_step(self.name)
        parent._last_added = step
        return step

    def set_flow(
        self,
        steps: list[Union[str, StepToDict, dict[str, Any]]],
        force: bool = False,
    ):
        """set a list of steps as downstream from this step, in the order specified.
        This overwrites any existing downstream steps.

        :param steps: list of steps to follow this one
        :param force: whether to overwrite existing downstream steps. If False, this method
                      fails if any downstream steps have already been defined. Defaults to False.
        :return: the last step added to the flow

        example:
        The code below sets the downstream nodes of step1 by using a list of steps (provided to
        `set_flow()`) and a single step (provided to `to()`), resulting in the graph
        (step1 -> step2 -> step3 -> step4). Notice that `force=True` is required if step1 already
        had downstream nodes (e.g. if the existing graph is step1 -> step2_old), and that after
        this code runs the existing downstream steps are removed. If the intention is to split
        the graph (and not to overwrite it), please use `to()`.

            step1.set_flow(
                [
                    dict(name="step2", handler="step2_handler"),
                    dict(name="step3", class_name="Step3Class"),
                ],
                force=True,
            ).to(dict(name="step4", class_name="Step4Class"))
        """
        raise NotImplementedError("set_flow() can only be called on a FlowStep")
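
# Illustrative sketch (not part of the original module): the input_path/result_path semantics
# described in the docstrings above. The graph, step name, and handler are hypothetical.
def _example_input_result_path(graph):
    # assuming the event body is {"data": {"a": 5, "b": 7}}, the step below receives 7 as its
    # input, and its output is written to event.body["prediction"],
    # i.e. {"data": {...}, "prediction": <result>}
    return graph.to(
        name="predict",
        handler="my_handler",
        input_path="data.b",
        result_path="prediction",
    )
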
class TaskStep(BaseStep):
    """task execution step, runs a class or handler"""

    kind = "task"
    _dict_fields = _task_step_fields
    _default_class = ""

    def __init__(
        self,
        class_name: Union[str, type] = None,
        class_args: dict = None,
        handler: str = None,
        name: str = None,
        after: list = None,
        full_event: bool = None,
        function: str = None,
        responder: bool = None,
        input_path: str = None,
        result_path: str = None,
    ):
        super().__init__(name, after)
        self.class_name = class_name
        self.class_args = class_args or {}
        self.handler = handler
        self.function = function
        self._handler = None
        self._object = None
        self._async_object = None
        self.skip_context = None
        self.context = None
        self._class_object = None
        self.responder = responder
        self.full_event = full_event
        self.input_path = input_path
        self.result_path = result_path
        self.on_error = None
        self._inject_context = False
        self._call_with_event = False

    def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs):
        self.context = context
        self._async_object = None
        if not self._is_local_function(context):
            # skip init of non-local functions
            return

        if self.handler and not self.class_name:
            # link to a function handler
            if callable(self.handler):
                self._handler = self.handler
                self.handler = self.handler.__name__
            else:
                self._handler = get_function(self.handler, namespace)
            args = signature(self._handler).parameters
            if args and "context" in list(args.keys()):
                self._inject_context = True
            self._set_error_handler()
            return

        self._class_object, self.class_name = self.get_step_class_object(
            namespace=namespace
        )
        if not self._object or reset:
            # init the step class + args
            extracted_class_args = self.get_full_class_args(
                namespace=namespace,
                class_object=self._class_object,
                **extra_kwargs,
            )
            try:
                self._object = self._class_object(**extracted_class_args)
            except TypeError as exc:
                raise TypeError(
                    f"failed to init step {self.name}\n args={self.class_args}"
                ) from exc

            # determine the right class handler to use
            handler = self.handler
            if handler:
                if not hasattr(self._object, handler):
                    raise GraphError(
                        f"handler ({handler}) specified but doesn't exist in class {self.class_name}"
                    )
            else:
                if hasattr(self._object, "do_event"):
                    handler = "do_event"
                    self._call_with_event = True
                elif hasattr(self._object, "do"):
                    handler = "do"
            if handler:
                self._handler = getattr(self._object, handler, None)

        self._set_error_handler()
        if mode != "skip":
            self._post_init(mode)

    def get_full_class_args(self, namespace, class_object, **extra_kwargs):
        class_args = {}
        for key, arg in self.class_args.items():
            if key.startswith(callable_prefix):
                class_args[key[1:]] = get_function(arg, namespace)
            else:
                class_args[key] = arg
        class_args.update(extra_kwargs)

        # add common args (name, context, ..) only if the target class can accept them
        argspec = getfullargspec(class_object)
        for key in ["name", "context", "input_path", "result_path", "full_event"]:
            if argspec.varkw or key in argspec.args:
                class_args[key] = getattr(self, key)
        if argspec.varkw or "graph_step" in argspec.args:
            class_args["graph_step"] = self
        return class_args

    def get_step_class_object(self, namespace):
        class_name = self.class_name
        class_object = self._class_object
        if isinstance(class_name, type):
            class_object = class_name
            class_name = class_name.__name__
        elif not class_object:
            if class_name == "$remote":
                from mlrun.serving.remote import RemoteStep

                class_object = RemoteStep
            else:
                class_object = get_class(class_name or self._default_class, namespace)
        return class_object, class_name

    def _is_local_function(self, context):
        # detect whether the class is local (and should be initialized)
        current_function = get_current_function(context)
        if current_function == "*":
            return True
        if not self.function and not current_function:
            return True
        if (
            self.function and self.function == "*"
        ) or self.function == current_function:
            return True
        return False

    @property
    def async_object(self):
        """return the sync or async (storey) class instance"""
        return self._async_object or self._object

    def clear_object(self):
        self._object = None

    def _post_init(self, mode="sync"):
        if self._object and hasattr(self._object, "post_init"):
            self._object.post_init(mode)
        if hasattr(self._object, "model_endpoint_uid"):
            self.endpoint_uid = self._object.model_endpoint_uid

    def respond(self):
        """mark this step as the responder.

        the step output will be returned as the flow result, no other step can follow
        """
        self.responder = True
        return self

    def run(self, event, *args, **kwargs):
        """run this step, in async flows the run is done through storey"""
        if not self._is_local_function(self.context):
            # todo invoke remote via REST call
            return event

        if self.context.verbose:
            self.context.logger.info(f"step {self.name} got event {event.body}")

        # inject the context parameter if it is expected by the handler
        if self._inject_context:
            kwargs["context"] = self.context
        elif kwargs and "context" in kwargs:
            del kwargs["context"]

        try:
            if self.full_event or self._call_with_event:
                return self._handler(event, *args, **kwargs)

            if self._handler is None:
                raise MLRunInvalidArgumentError(
                    f"step {self.name} does not have a handler"
                )

            result = self._handler(
                _extract_input_data(self.input_path, event.body), *args, **kwargs
            )
            event.body = _update_result_body(self.result_path, event.body, result)
        except Exception as exc:
            if self._on_error_handler:
                self._log_error(event, exc)
                result = self._call_error_handler(event, exc)
                event.body = _update_result_body(self.result_path, event.body, result)
            else:
                raise exc
        return event
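
# Illustrative sketch (not part of the original module): how TaskStep resolves its handler in
# init_object() above. The class below is a hypothetical user class; with no explicit `handler`,
# TaskStep falls back to a `do_event` or `do` method.
class _ExampleStepClass:
    def __init__(self, name=None, context=None, multiplier=2):
        self.name = name          # passed in because get_full_class_args() adds common args when accepted
        self.context = context
        self.multiplier = multiplier

    def do(self, body):
        # called per event with the (optionally input_path-selected) event body
        return body * self.multiplier
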
class ErrorStep(TaskStep):
    """error execution step, runs a class or handler"""

    kind = "error_step"
    _dict_fields = _task_step_fields + ["before", "base_step"]
    _default_class = ""

    def __init__(
        self,
        class_name: Union[str, type] = None,
        class_args: dict = None,
        handler: str = None,
        name: str = None,
        after: list = None,
        full_event: bool = None,
        function: str = None,
        responder: bool = None,
        input_path: str = None,
        result_path: str = None,
    ):
        super().__init__(
            class_name=class_name,
            class_args=class_args,
            handler=handler,
            name=name,
            after=after,
            full_event=full_event,
            function=function,
            responder=responder,
            input_path=input_path,
            result_path=result_path,
        )
        self.before = None
        self.base_step = None

class RouterStep(TaskStep):
    """router step, implements routing logic for running child routes"""

    kind = "router"
    default_shape = "doubleoctagon"
    _dict_fields = _task_step_fields + ["routes"]
    _default_class = "mlrun.serving.ModelRouter"

    def __init__(
        self,
        class_name: Union[str, type] = None,
        class_args: dict = None,
        handler: str = None,
        routes: list = None,
        name: str = None,
        function: str = None,
        input_path: str = None,
        result_path: str = None,
    ):
        super().__init__(
            class_name,
            class_args,
            handler,
            name=name,
            function=function,
            input_path=input_path,
            result_path=result_path,
        )
        self._routes: ObjectDict = None
        self.routes = routes

    def get_children(self):
        """get child steps (routes)"""
        return self._routes.values()

    @property
    def routes(self):
        """child routes/steps, traffic is routed to the routes based on the router logic"""
        return self._routes

    @routes.setter
    def routes(self, routes: dict):
        self._routes = ObjectDict.from_dict(classes_map, routes, "task")

    def add_route(
        self,
        key,
        route=None,
        class_name=None,
        handler=None,
        function=None,
        **class_args,
    ):
        """add a child route step or class to the router

        :param key:        unique name (and route path) for the child step
        :param route:      child step object (Task, ..)
        :param class_name: class name to build the route step from (when route is not provided)
        :param class_args: class init arguments
        :param handler:    class handler to invoke on run/event
        :param function:   function this step should run in
        """
        if not route and not class_name and not handler:
            raise MLRunInvalidArgumentError(
                "route, class_name, or handler must be specified"
            )
        if not route:
            route = TaskStep(class_name, class_args, handler=handler)
        route.function = function or route.function

        route = self._routes.update(key, route)
        route.set_parent(self)
        return route

    def clear_children(self, routes: list):
        """clear child steps (routes)"""
        if not routes:
            routes = self._routes.keys()
        for key in routes:
            del self._routes[key]

    def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs):
        if not self._is_local_function(context):
            return

        self.class_args = self.class_args or {}
        super().init_object(
            context, namespace, "skip", reset=reset, routes=self._routes, **extra_kwargs
        )

        for route in self._routes.values():
            if self.function and not route.function:
                # if the router runs on a child function and the
                # model function is not specified, use the router function
                route.function = self.function
            route.set_parent(self)
            route.init_object(context, namespace, mode, reset=reset)
        self._set_error_handler()
        self._post_init(mode)

    def __getitem__(self, name):
        return self._routes[name]

    def __setitem__(self, name, route):
        self.add_route(name, route)

    def __delitem__(self, key):
        del self._routes[key]

    def __iter__(self):
        yield from self._routes.keys()

    def plot(self, filename=None, format=None, source=None, **kw):
        """plot/save graph using graphviz

        :param filename: target filepath for the image (None for the notebook)
        :param format:   the output format used for rendering (``'pdf'``, ``'png'``, etc.)
        :param source:   source step to add to the graph
        :param kw:       kwargs passed to graphviz, e.g. rankdir="LR"
                         (see: https://graphviz.org/doc/info/attrs.html)
        :return:         graphviz graph object
        """
        return _generate_graphviz(
            self, _add_graphviz_router, filename, format, source=source, **kw
        )
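
# Illustrative sketch (not part of the original module): adding routes to a router step via
# add_route() above. The '*' class name (handled by params_to_step() below) falls back to the
# default router class; the route keys, class names, and model paths are hypothetical.
def _example_router(graph):
    router = graph.add_step("*", name="router")
    router.add_route("model_a", class_name="MyModelClass", model_path="store://models/my-proj/a")
    router.add_route("model_b", class_name="MyModelClass", model_path="store://models/my-proj/b")
    return router
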
class QueueStep(BaseStep):
    """queue step, implements an async queue or represents a stream"""

    kind = "queue"
    default_shape = "cds"
    _dict_fields = BaseStep._dict_fields + [
        "path",
        "shards",
        "retention_in_hours",
        "trigger_args",
        "options",
    ]

    def __init__(
        self,
        name: str = None,
        path: str = None,
        after: list = None,
        shards: int = None,
        retention_in_hours: int = None,
        trigger_args: dict = None,
        **options,
    ):
        super().__init__(name, after)
        self.path = path
        self.shards = shards
        self.retention_in_hours = retention_in_hours
        self.options = options
        self.trigger_args = trigger_args
        self._stream = None
        self._async_object = None

    def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs):
        self.context = context
        if self.path:
            self._stream = get_stream_pusher(
                self.path,
                shards=self.shards,
                retention_in_hours=self.retention_in_hours,
                **self.options,
            )
        self._set_error_handler()

    @property
    def async_object(self):
        return self._async_object

    def run(self, event, *args, **kwargs):
        data = event.body
        if not data:
            return event

        if self._stream:
            self._stream.push({"id": event.id, "body": data, "path": event.path})
            event.terminated = True
            event.body = None
        return event
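
# Illustrative sketch (not part of the original module): wiring a queue/stream step into a graph
# using the '>>' class name (handled by params_to_step() below). The stream path, step names, and
# child function name are hypothetical; note that steps reading from a queue must specify a function.
def _example_queue(graph):
    return (
        graph.to(name="enrich", handler="enrich_handler")
        .to(">>", "to_stream", path="v3io:///projects/my-proj/streams/out")
        .to(name="consumer", class_name="ConsumerClass", function="consumer-function")
    )
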
class FlowStep(BaseStep):
    """flow step, represents a workflow or DAG"""

    kind = "flow"
    _dict_fields = BaseStep._dict_fields + [
        "steps",
        "engine",
        "default_final_step",
    ]

    def __init__(
        self,
        name=None,
        steps=None,
        after: list = None,
        engine=None,
        final_step=None,
    ):
        super().__init__(name, after)
        self._steps = None
        self.steps = steps
        self.engine = engine
        self.from_step = os.environ.get("START_FROM_STEP", None)
        self.final_step = final_step

        self._last_added = None
        self._controller = None
        self._wait_for_result = False
        self._source = None
        self._start_steps = []

    def get_children(self):
        return self._steps.values()

    @property
    def steps(self):
        """child (workflow) steps"""
        return self._steps

    @property
    def controller(self):
        """async (storey) flow controller"""
        return self._controller

    @steps.setter
    def steps(self, steps):
        self._steps = ObjectDict.from_dict(classes_map, steps, "task")

    def add_step(
        self,
        class_name=None,
        name=None,
        handler=None,
        after=None,
        before=None,
        graph_shape=None,
        function=None,
        full_event: bool = None,
        input_path: str = None,
        result_path: str = None,
        **class_args,
    ):
        """add a task, queue or router step/class to the flow

        use after/before to insert into a specific location

        example:
            graph = fn.set_topology("flow", exist_ok=True)
            graph.add_step(class_name="Chain", name="s1")
            graph.add_step(class_name="Chain", name="s3", after="$prev")
            graph.add_step(class_name="Chain", name="s2", after="s1", before="s3")

        :param class_name:  class name or step object to build the step from
                            for router steps the class name should start with '*'
                            for queue/stream steps the class should be '>>' or '$queue'
        :param name:        unique name (and path) for the child step, default is the class name
        :param after:       the step name this step comes after
                            can use $prev to indicate the last added step
        :param before:      string or list of next step names that will run after this step
        :param handler:     class/function handler to invoke on run/event
        :param graph_shape: graphviz shape name
        :param function:    function this step should run in
        :param full_event:  this step accepts the full event (not just the body)
        :param input_path:  selects the key/path in the event to use as input to the step
                            this requires that the event body behaves like a dict, for example:
                            event: {"data": {"a": 5, "b": 7}}, input_path="data.b" means the step will
                            receive 7 as input
        :param result_path: selects the key/path in the event to write the results to
                            this requires that the event body behaves like a dict, for example:
                            event: {"x": 5}, result_path="y" means the output of the step will be written
                            to event["y"], resulting in {"x": 5, "y": <result>}
        :param class_args:  class init arguments
        """
        name, step = params_to_step(
            class_name,
            name,
            handler,
            graph_shape=graph_shape,
            function=function,
            full_event=full_event,
            input_path=input_path,
            result_path=result_path,
            class_args=class_args,
        )

        after_list = after if isinstance(after, list) else [after]
        for after in after_list:
            self.insert_step(name, step, after, before)
        return step

    def insert_step(self, key, step, after, before=None):
        """insert a step object into the flow, specify before and after"""
        step = self._steps.update(key, step)
        step.set_parent(self)

        if after == "$prev" and len(self._steps) == 1:
            after = None

        previous = ""
        if after:
            if after == "$prev" and self._last_added:
                previous = self._last_added.name
            else:
                if after not in self._steps.keys():
                    raise MLRunInvalidArgumentError(
                        f"can't set after, there is no step named {after}"
                    )
                previous = after
            step.after_step(previous)

        if before:
            if before not in self._steps.keys():
                raise MLRunInvalidArgumentError(
                    f"can't set before, there is no step named {before}"
                )
            if before == step.name or before == previous:
                raise GraphError(
                    f"graph loop, step {before} is specified in before and/or after {key}"
                )
            self[step.name].after_step(*self[before].after, append=False)
            self[before].after_step(step.name, append=False)
        self._last_added = step
        return step

    def clear_children(self, steps: list = None):
        """remove some or all of the steps (empty/None for all)"""
        if not steps:
            steps = self._steps.keys()
        for key in steps:
            del self._steps[key]

    def __getitem__(self, name):
        return self._steps[name]

    def __setitem__(self, name, step):
        self.add_step(name, step)

    def __delitem__(self, key):
        del self._steps[key]

    def __iter__(self):
        yield from self._steps.keys()

    def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs):
        """initialize graph objects and classes"""
        self.context = context
        self._insert_all_error_handlers()
        self.check_and_process_graph()

        for step in self._steps.values():
            step.set_parent(self)
            step.init_object(context, namespace, mode, reset=reset)
        self._set_error_handler()
        self._post_init(mode)

        if self.engine != "sync":
            self._build_async_flow()
            self._run_async_flow()

    def check_and_process_graph(self, allow_empty=False):
        """validate correct graph layout and initialize the .next links"""

        if self.is_empty() and allow_empty:
            self._start_steps = []
            return [], None, []

        def has_loop(step, previous):
            for next_step in step.after or []:
                if next_step in previous:
                    return step.name
                downstream = has_loop(self[next_step], previous + [next_step])
                if downstream:
                    return downstream
            return None

        start_steps = []
        for step in self._steps.values():
            step._next = None
            step._visited = False
            if step.after:
                loop_step = has_loop(step, [])
                if loop_step:
                    raise GraphError(
                        f"Error, loop detected in step {loop_step}, graph must be acyclic (DAG)"
                    )
            else:
                start_steps.append(step.name)

        responders = []
        for step in self._steps.values():
            if (
                hasattr(step, "responder")
                and step.responder
                and step.kind != "error_step"
            ):
                responders.append(step.name)
            if step.on_error and step.on_error in start_steps:
                start_steps.remove(step.on_error)
            if step.after:
                for prev_step in step.after:
                    self[prev_step].set_next(step.name)
        if self.on_error and self.on_error in start_steps:
            start_steps.remove(self.on_error)

        if (
            len(responders) > 1
        ):  # there should not be multiple steps that respond to the request
            raise GraphError(
                f'there is more than one responder step in the graph ({",".join(responders)})'
            )

        if self.from_step:
            if self.from_step not in self.steps:
                raise GraphError(
                    f"from_step ({self.from_step}) specified and not found in graph steps"
                )
            start_steps = [self.from_step]
        self._start_steps = [self[name] for name in start_steps]

        def get_first_function_step(step, current_function):
            # find the first step which belongs to the function
            if (
                hasattr(step, "function")
                and step.function
                and step.function == current_function
            ):
                return step
            for item in step.next or []:
                next_step = self[item]
                returned_step = get_first_function_step(next_step, current_function)
                if returned_step:
                    return returned_step

        current_function = get_current_function(self.context)
        if current_function and current_function != "*":
            new_start_steps = []
            for from_step in self._start_steps:
                step = get_first_function_step(from_step, current_function)
                if step:
                    new_start_steps.append(step)
            if not new_start_steps:
                raise GraphError(
                    f"did not find steps pointing to current function ({current_function})"
                )
            self._start_steps = new_start_steps

        if self.engine == "sync" and len(self._start_steps) > 1:
            raise GraphError(
                "sync engine can only have one starting step (without .after)"
            )

        default_final_step = None
        if self.final_step:
            if self.final_step not in self.steps:
                raise GraphError(
                    f"final_step ({self.final_step}) specified and not found in graph steps"
                )
            default_final_step = self.final_step
        elif len(self._start_steps) == 1:
            # find the final step in case of a simple sequence of steps
            next_obj = self._start_steps[0]
            while next_obj:
                next = next_obj.next
                if not next:
                    default_final_step = next_obj.name
                    break
                next_obj = self[next[0]] if len(next) == 1 else None

        return self._start_steps, default_final_step, responders

    def set_flow_source(self, source):
        """set the async flow (storey) source"""
        self._source = source

    def _build_async_flow(self):
        """initialize and build the async/storey DAG"""

        def process_step(state, step, root):
            if not state._is_local_function(self.context) or state._visited:
                return
            for item in state.next or []:
                next_state = root[item]
                if next_state.async_object:
                    next_step = step.to(next_state.async_object)
                    process_step(next_state, next_step, root)
            state._visited = True  # mark visited to avoid re-visiting when there are multiple uplinks

        default_source, self._wait_for_result = _init_async_objects(
            self.context, self._steps.values()
        )

        source = self._source or default_source
        for next_state in self._start_steps:
            next_step = source.to(next_state.async_object)
            process_step(next_state, next_step, self)

        for step in self._steps.values():
            # add error handler hooks
            if (step.on_error or self.on_error) and step.async_object:
                error_step = self._steps[step.on_error or self.on_error]
                # never set a step as its own error handler
                if step != error_step:
                    step.async_object.set_recovery_step(error_step.async_object)
                    for next_step in error_step.next or []:
                        next_state = self[next_step]
                        if next_state.async_object and error_step.async_object:
                            error_step.async_object.to(next_state.async_object)

        self._async_flow = source

    def _run_async_flow(self):
        self._controller = self._async_flow.run()

    def get_queue_links(self):
        """return a dict mapping each function to the queue it listens on, for building stream triggers"""
        links = {}
        for step in self.get_children():
            if step.kind == StepKinds.queue:
                for item in step.next or []:
                    next_step = self[item]
                    if not next_step.function:
                        raise GraphError(
                            f"child function name must be specified in steps ({next_step.name}) which follow a queue"
                        )
                    if next_step.function in links:
                        raise GraphError(
                            f"function ({next_step.function}) cannot read from multiple queues"
                        )
                    links[next_step.function] = step
        return links

    def init_queues(self):
        """init/create the streams used in this flow"""
        for step in self.get_children():
            if step.kind == StepKinds.queue:
                step.init_object(self.context, None)

    def list_child_functions(self):
        """return a list of child function names referred to in the steps"""
        functions = []
        for step in self.get_children():
            if (
                hasattr(step, "function")
                and step.function
                and step.function not in functions
            ):
                functions.append(step.function)
        return functions

    def is_empty(self):
        """is the graph empty (no child steps)"""
        return len(self.steps) == 0

    @staticmethod
    async def _await_and_return_id(awaitable, event):
        await awaitable
        event = copy(event)
        event.body = {"id": event.id}
        return event

    def run(self, event, *args, **kwargs):
        if self._controller:
            # async flow (using storey)
            event._awaitable_result = None
            resp = self._controller.emit(
                event, return_awaitable_result=self._wait_for_result
            )
            if self._wait_for_result and resp:
                return resp.await_result()
            event = copy(event)
            event.body = {"id": event.id}
            return event

        if len(self._start_steps) == 0:
            return event
        next_obj = self._start_steps[0]
        while next_obj:
            try:
                event = next_obj.run(event, *args, **kwargs)
            except Exception as exc:
                if self._on_error_handler:
                    self._log_error(event, exc, failed_step=next_obj.name)
                    event.body = self._call_error_handler(event, exc)
                    event.terminated = True
                    return event
                else:
                    raise exc

            if hasattr(event, "terminated") and event.terminated:
                return event
            if (
                hasattr(event, "error")
                and isinstance(event.error, dict)
                and next_obj.name in event.error
            ):
                next_obj = self._steps[next_obj.on_error]

            next = next_obj.next
            if next and len(next) > 1:
                raise GraphError(
                    f"the synchronous flow engine doesn't support branches, use the async engine, step={next_obj.name}"
                )
            next_obj = self[next[0]] if next else None
        return event

    def wait_for_completion(self):
        """wait for completion of the run in async flows"""
        if self._controller:
            if hasattr(self._controller, "terminate"):
                self._controller.terminate()
            return self._controller.await_termination()

    def plot(self, filename=None, format=None, source=None, targets=None, **kw):
        """plot/save graph using graphviz

        :param filename: target filepath for the graph image (None for the notebook)
        :param format:   the output format used for rendering (``'pdf'``, ``'png'``, etc.)
        :param source:   source step to add to the graph image
        :param targets:  list of target steps to add to the graph image
        :param kw:       kwargs passed to graphviz, e.g. rankdir="LR"
                         (see: https://graphviz.org/doc/info/attrs.html)
        :return:         graphviz graph object
        """
        return _generate_graphviz(
            self,
            _add_graphviz_flow,
            filename,
            format,
            source=source,
            targets=targets,
            **kw,
        )

    def _insert_all_error_handlers(self):
        """
        insert all error steps into the graph
        runs after deployment
        """
        for name, step in self._steps.items():
            if step.kind == "error_step":
                self._insert_error_step(name, step)

    def _insert_error_step(self, name, step):
        """
        insert an error step into the graph
        runs after deployment
        """
        if not step.before and not any(
            [step.name in other_step.after for other_step in self._steps.values()]
        ):
            step.responder = True
            return

        for step_name in step.before:
            if step_name not in self._steps.keys():
                raise MLRunInvalidArgumentError(
                    f"can't set before, there is no step named {step_name}"
                )
            self[step_name].after_step(name)

    def set_flow(
        self,
        steps: list[Union[str, StepToDict, dict[str, Any]]],
        force: bool = False,
    ):
        if not force and self.steps:
            raise mlrun.errors.MLRunInvalidArgumentError(
                "set_flow() called on a step that already has downstream steps. "
                "If you want to overwrite existing steps, set force=True."
            )

        self.steps = None
        step = self
        for next_step in steps:
            if isinstance(next_step, dict):
                step = step.to(**next_step)
            else:
                step = step.to(next_step)
        return step


class RootFlowStep(FlowStep):
    """root flow step"""

    kind = "root"
    _dict_fields = ["steps", "engine", "final_step", "on_error"]


classes_map = {
    "task": TaskStep,
    "router": RouterStep,
    "flow": FlowStep,
    "queue": QueueStep,
    "error_step": ErrorStep,
}


def get_current_function(context):
    if context and hasattr(context, "current_function"):
        return context.current_function or ""
    return ""


def _add_graphviz_router(graph, step, source=None, **kwargs):
    if source:
        graph.node("_start", source.name, shape=source.shape, style="filled")
        graph.edge("_start", step.fullname)

    graph.node(step.fullname, label=step.name, shape=step.get_shape())
    for route in step.get_children():
        graph.node(route.fullname, label=route.name, shape=route.get_shape())
        graph.edge(step.fullname, route.fullname)


def _add_graphviz_flow(
    graph,
    step,
    source=None,
    targets=None,
):
    start_steps, default_final_step, responders = step.check_and_process_graph(
        allow_empty=True
    )
    graph.node("_start", source.name, shape=source.shape, style="filled")
    for start_step in start_steps:
        graph.edge("_start", start_step.fullname)
    for child in step.get_children():
        kind = child.kind
        if kind == StepKinds.router:
            with graph.subgraph(name="cluster_" + child.fullname) as sg:
                _add_graphviz_router(sg, child)
        else:
            graph.node(child.fullname, label=child.name, shape=child.get_shape())
        _add_edges(child.after or [], step, graph, child)
        _add_edges(getattr(child, "before", []), step, graph, child, after=False)
        if child.on_error:
            graph.edge(child.fullname, child.on_error, style="dashed")

    # draw targets after the last step (if specified)
    if targets:
        for target in targets or []:
            target_kind, target_name = target.name.split("/", 1)
            if target_kind != target_name:
                label = (
                    f"<{target_name}<br/><font point-size='8'>({target_kind})</font>>"
                )
            else:
                label = target_name
            graph.node(target.fullname, label=label, shape=target.get_shape())
            last_step = target.after or default_final_step
            if last_step:
                graph.edge(last_step, target.fullname)


def _add_edges(items, step, graph, child, after=True):
    for item in items:
        next_or_prev_object = step[item]
        kw = {}
        if next_or_prev_object.kind == StepKinds.router:
            kw["ltail"] = f"cluster_{next_or_prev_object.fullname}"
        if after:
            graph.edge(next_or_prev_object.fullname, child.fullname, **kw)
        else:
            graph.edge(child.fullname, next_or_prev_object.fullname, **kw)


def _generate_graphviz(
    step,
    renderer,
    filename=None,
    format=None,
    source=None,
    targets=None,
    **kw,
):
    try:
        from graphviz import Digraph
    except ImportError:
        raise ImportError(
            'graphviz is not installed, run "pip install graphviz" first!'
        )
    graph = Digraph("mlrun-flow", format="jpg")
    graph.attr(compound="true", **kw)
    source = source or BaseStep("start", shape="egg")
    renderer(graph, step, source=source, targets=targets)
    if filename:
        suffix = pathlib.Path(filename).suffix
        if suffix:
            filename = filename[: -len(suffix)]
            format = format or suffix[1:]
        format = format or "png"
        graph.render(filename, format=format)
    return graph


def graph_root_setter(server, graph):
    """set the graph root object from a class or dict"""
    if graph:
        if isinstance(graph, dict):
            kind = graph.get("kind")
        elif hasattr(graph, "kind"):
            kind = graph.kind
        else:
            raise MLRunInvalidArgumentError("graph must be a dict or a valid object")
        if kind == StepKinds.router:
            server._graph = server._verify_dict(graph, "graph", RouterStep)
        elif not kind or kind == StepKinds.root:
            server._graph = server._verify_dict(graph, "graph", RootFlowStep)
        else:
            raise GraphError(f"illegal root step {kind}")


def get_name(name, class_name):
    """get the task name from the provided name or class"""
    if name:
        return name
    if not class_name:
        raise MLRunInvalidArgumentError("name or class_name must be provided")
    if isinstance(class_name, type):
        return class_name.__name__
    return class_name


def params_to_step(
    class_name,
    name,
    handler=None,
    graph_shape=None,
    function=None,
    full_event=None,
    input_path: str = None,
    result_path: str = None,
    class_args=None,
):
    """return a step object built from the provided params or classes/objects"""

    class_args = class_args or {}

    if class_name and hasattr(class_name, "to_dict"):
        struct = class_name.to_dict()
        kind = struct.get("kind", StepKinds.task)
        name = name or struct.get("name", struct.get("class_name"))
        cls = classes_map.get(kind, RootFlowStep)
        step = cls.from_dict(struct)
        step.function = function
        step.full_event = full_event or step.full_event
        step.input_path = input_path or step.input_path
        step.result_path = result_path or step.result_path
    elif class_name and class_name in queue_class_names:
        if "path" not in class_args:
            raise MLRunInvalidArgumentError(
                "path=<stream path or None> must be specified for queues"
            )
        if not name:
            raise MLRunInvalidArgumentError("queue name must be specified")
        # pass full_event on only if it's explicitly defined
        if full_event is not None:
            class_args = class_args.copy()
            class_args["full_event"] = full_event
        step = QueueStep(name, **class_args)
    elif class_name and class_name.startswith("*"):
        routes = class_args.get("routes", None)
        class_name = class_name[1:]
        name = get_name(name, class_name or "router")
        step = RouterStep(
            class_name,
            class_args,
            handler,
            name=name,
            function=function,
            routes=routes,
            input_path=input_path,
            result_path=result_path,
        )
    elif class_name or handler:
        name = get_name(name, class_name)
        step = TaskStep(
            class_name,
            class_args,
            handler,
            name=name,
            function=function,
            full_event=full_event,
            input_path=input_path,
            result_path=result_path,
        )
    else:
        raise MLRunInvalidArgumentError("class_name or handler must be provided")

    if graph_shape:
        step.shape = graph_shape
    return name, step


def _init_async_objects(context, steps):
    try:
        import storey
    except ImportError:
        raise GraphError("storey package is not installed, use pip install storey")

    wait_for_result = False

    trigger = getattr(context, "trigger", None)
    context.logger.debug(f"trigger is {trigger or 'unknown'}")
    # respond is only supported for the HTTP trigger
    respond_supported = trigger is None or trigger == "http"

    for step in steps:
        if hasattr(step, "async_object") and step._is_local_function(context):
            if step.kind == StepKinds.queue:
                skip_stream = context.is_mock and step.next
                if step.path and not skip_stream:
                    stream_path = step.path
                    endpoint = None
                    options = {}
                    options.update(step.options)
                    kafka_brokers = get_kafka_brokers_from_dict(options, pop=True)
                    if stream_path.startswith("kafka://") or kafka_brokers:
                        topic, brokers = parse_kafka_url(stream_path, kafka_brokers)
                        kafka_producer_options = options.pop(
                            "kafka_producer_options", None
                        )
                        step._async_object = storey.KafkaTarget(
                            topic=topic,
                            brokers=brokers,
                            producer_options=kafka_producer_options,
                            context=context,
                            **options,
                        )
                    else:
                        if stream_path.startswith("v3io://"):
                            endpoint, stream_path = parse_path(step.path)
                            stream_path = stream_path.strip("/")
                        step._async_object = storey.StreamTarget(
                            storey.V3ioDriver(endpoint or config.v3io_api),
                            stream_path,
                            context=context,
                            **options,
                        )
                else:
                    step._async_object = storey.Map(lambda x: x)

            elif not step.async_object or not hasattr(step.async_object, "_outlets"):
                # if it's a regular class, wrap it with a storey Map
                step._async_object = storey.Map(
                    step._handler,
                    full_event=step.full_event or step._call_with_event,
                    input_path=step.input_path,
                    result_path=step.result_path,
                    name=step.name,
                    context=context,
                    pass_context=step._inject_context,
                )
            if (
                respond_supported
                and not step.next
                and hasattr(step, "responder")
                and step.responder
            ):
                # if it's a responder step (returns the result), add Complete()
                step.async_object.to(storey.Complete(full_event=True))
                wait_for_result = True

    source_args = context.get_param("source_args", {})
    explicit_ack = is_explicit_ack_supported(context) and mlrun.mlconf.is_explicit_ack()

    # TODO: Change to AsyncEmitSource once we can drop support for nuclio<1.12.10
    default_source = storey.SyncEmitSource(
        context=context,
        explicit_ack=explicit_ack,
        **source_args,
    )
    return default_source, wait_for_result
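
# Illustrative sketch (not part of the original module): putting the pieces together. `fn` is
# assumed to be a serving function created elsewhere (e.g. mlrun.code_to_function(..., kind="serving"));
# the step names, handlers, and classes are hypothetical.
def _example_flow(fn):
    graph = fn.set_topology("flow", engine="async")
    graph.to(name="preprocess", handler="preprocess") \
        .to(name="predict", class_name="PredictClass") \
        .error_handler(name="catcher", handler="handle_error", full_event=True)
    graph.add_step(name="postprocess", handler="postprocess", after="predict").respond()
    # render the graph left to right (requires the graphviz package)
    return graph.plot(rankdir="LR")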