Source code for qlauncher.workflow.workflow_manager

  1import concurrent.futures
  2from collections.abc import Callable
  3from typing import Any, Literal
  4
  5from qlauncher.base import Algorithm, Backend, Model, Problem
  6
  7
[docs] 8class Task: 9 def __init__(self, func: Callable, args: tuple[Any] | None = None, kwargs: dict[str, Any] | None = None, num_output: int = 1): 10 if args is None: 11 args = tuple() 12 if kwargs is None: 13 kwargs = {} 14 self.func = func 15 self.dependencies: list[Task] = [arg for arg in args if isinstance(arg, Task)] 16 self.dependencies.extend([value for value in kwargs.values() if isinstance(value, Task)]) 17 self.args = args 18 self.kwargs = kwargs 19 self.done = False 20 self.result = None 21 self.num_output = num_output 22 self.subtasks: list[SubTask] = [] 23
[docs] 24 def run(self) -> None: 25 binded_args = [arg.result if isinstance(arg, Task) else arg for arg in self.args] 26 binded_kwargs = {key: (value.result if isinstance(value, Task) else value) for key, value in self.kwargs.items()} 27 self.result = self.func(*binded_args, **binded_kwargs) 28 self.done = True
29
[docs] 30 def is_ready(self): 31 return all(map(lambda x: x.done, self.dependencies))
32 33 def __iter__(self): 34 for i in range(self.num_output): 35 yield SubTask(self, i)
36 37
class SubTask(Task):
    """Read-only view onto one element of a parent task's result.

    ``Task.__init__`` is not called here, so the instance only carries
    ``task`` and ``index``. ``done`` and ``result`` are derived from the
    parent task on every access.
    """

    def __init__(self, task: Task, index: int):
        """Remember the parent ``task`` and the ``index`` to pick from its result."""
        self.index = index
        self.task = task

    @property
    def done(self):
        """Whether the parent task has finished."""
        return self.task.done

    @property
    def result(self):
        """Element ``index`` of the parent task's result."""
        parent_result = self.task.result
        return parent_result[self.index]
50 51
class Workflow(Algorithm):
    """Algorithm that runs a fixed collection of tasks on a thread pool.

    The problem handed to :meth:`run` is stored as the result of
    ``input_task``; the value returned is the result of ``output_task``.
    """

    def __init__(self, tasks: list[Task], input_task: Task, output_task: Task, input_format: type[Problem | Model]):
        """Keep references to the tasks, the input/output tasks and the input format."""
        self.input_format = input_format
        self.output_task = output_task
        self.input_task = input_task
        self.tasks = tasks

    def run(self, problem: Algorithm, backend: Backend) -> Any:  # noqa: ANN401
        """Feed ``problem`` into the input task, execute all tasks, return the output result.

        ``backend`` is accepted but not used by this method.
        """
        self.input_task.result = problem
        with concurrent.futures.ThreadPoolExecutor() as pool:
            _execute_workflow(self.tasks, pool)
        return self.output_task.result

    def get_input_format(self) -> type[Model]:
        """Return ``Model`` when the input format is a ``Problem`` subclass, else the format itself."""
        return Model if issubclass(self.input_format, Problem) else self.input_format
69 70
class WorkflowManager:
    """Builder that collects tasks, marks an input and an output, and runs them.

    The instance can be used as a context manager (entering returns the
    manager, exiting does nothing), called directly to execute the tasks, or
    converted to a ``Workflow`` with :meth:`to_workflow`.
    """

    def __init__(self, manager: Literal['ql', 'prefect', 'airflow'] = 'ql'):
        """Create an empty manager.

        Args:
            manager: Name of the workflow engine. It is stored on the instance
                but not otherwise consulted by this class.
        """
        self.tasks: list[Task] = []
        self.manager = manager
        self.input_task: Task | None = None
        self.input_task_format: type[Problem | Model] = Model
        self.output_task: Task | None = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass

    def task(self, func, args: tuple | None = None, kwargs: dict | None = None, num_output: int | None = None) -> Task:
        """Create a task for ``func``, register it with the manager and return it.

        Args:
            func: Callable the task will invoke.
            args: Positional arguments; ``Task`` instances become dependencies.
            kwargs: Keyword arguments; ``Task`` values become dependencies.
            num_output: Number of sub-results the task can be unpacked into.
                ``None`` means a single output.
        """
        args = args or ()
        kwargs = kwargs or {}
        # Forwarding ``None`` would make iterating the task fail on
        # ``range(None)``, so fall back to a single output.
        new_task = Task(func, args, kwargs, num_output=1 if num_output is None else num_output)
        self.tasks.append(new_task)
        return new_task

    def __call__(self, input_value=None, /):
        """Run every registered task and return the output task's result.

        ``input_value`` is stored as the result of the input task when one has
        been declared. Returns ``None`` when no output task was set.
        """
        if self.input_task is not None:
            self.input_task.result = input_value
        with concurrent.futures.ThreadPoolExecutor() as executor:
            _execute_workflow(self.tasks, executor)
        if self.output_task is not None:
            return self.output_task.result
        return None

    @staticmethod
    def _task_label(task: Task) -> str:
        """Return a printable name for ``task`` without assuming ``func.__name__`` exists."""
        if isinstance(task, SubTask):
            # A sub-task has no ``func`` of its own; name it after its parent.
            return f'{WorkflowManager._task_label(task.task)}[{task.index}]'
        func = getattr(task, 'func', None)
        if func is None:
            # The placeholder created by ``input`` has ``func=None``.
            return '<input>'
        # Callables such as ``functools.partial`` objects carry no ``__name__``.
        return getattr(func, '__name__', repr(func))

    def print_dag(self) -> None:
        """Print one ``name -> [dependency names]`` line per registered task."""
        for task in self.tasks:
            dep_names = [self._task_label(dep) for dep in task.dependencies]
            print(f'{self._task_label(task)} -> {dep_names}')

    def input(self, format: type[Problem | Model]) -> Task:
        """Declare the workflow input and return its placeholder task.

        The placeholder has no callable and is marked done immediately. It is
        not added to ``self.tasks``; its result is assigned when the workflow
        is executed.
        """
        self.input_task = Task(func=None)
        self.input_task.done = True
        self.input_task_format = format
        return self.input_task

    def output(self, task: Task) -> None:
        """Mark ``task`` as the one whose result the workflow returns."""
        self.output_task = task

    def to_workflow(self) -> Workflow:
        """Build a ``Workflow`` from the registered tasks, input, output and input format."""
        return Workflow(self.tasks, self.input_task, self.output_task, input_format=self.input_task_format)
117 118 119def _execute_workflow(tasks: list[Task], executor: concurrent.futures.Executor, max_iterations: int | None = None) -> None: 120 remaining_tasks = set(tasks) 121 max_iterations: int = max_iterations or len(remaining_tasks) 122 iteration = 0 123 for _ in range(max_iterations): 124 ready_tasks = list(filter(Task.is_ready, remaining_tasks)) 125 126 if len(ready_tasks) < 1: 127 if remaining_tasks: 128 raise RuntimeError('Cycle or error in tasks.') 129 return 130 131 futures = {executor.submit(task.run): task for task in ready_tasks} 132 for future in concurrent.futures.as_completed(futures): 133 if future.exception(): 134 raise future.exception() 135 136 for t in ready_tasks: 137 remaining_tasks.remove(t) 138 139 if iteration > max_iterations: 140 raise RuntimeError('Processing take too much iterations') 141 iteration += 1