1import concurrent.futures
2from collections.abc import Callable
3from typing import Any, Literal
4
5from qlauncher.base import Algorithm, Backend, Model, Problem
6
7
[docs]
8class Task:
9 def __init__(self, func: Callable, args: tuple[Any] | None = None, kwargs: dict[str, Any] | None = None, num_output: int = 1):
10 if args is None:
11 args = tuple()
12 if kwargs is None:
13 kwargs = {}
14 self.func = func
15 self.dependencies: list[Task] = [arg for arg in args if isinstance(arg, Task)]
16 self.dependencies.extend([value for value in kwargs.values() if isinstance(value, Task)])
17 self.args = args
18 self.kwargs = kwargs
19 self.done = False
20 self.result = None
21 self.num_output = num_output
22 self.subtasks: list[SubTask] = []
23
[docs]
24 def run(self) -> None:
25 binded_args = [arg.result if isinstance(arg, Task) else arg for arg in self.args]
26 binded_kwargs = {key: (value.result if isinstance(value, Task) else value) for key, value in self.kwargs.items()}
27 self.result = self.func(*binded_args, **binded_kwargs)
28 self.done = True
29
[docs]
30 def is_ready(self):
31 return all(map(lambda x: x.done, self.dependencies))
32
33 def __iter__(self):
34 for i in range(self.num_output):
35 yield SubTask(self, i)
36
37
[docs]
class SubTask(Task):
    """Read-only view onto one element of a parent task's result.

    Created by iterating a ``Task``. ``result`` is ``task.result[index]`` and
    ``done`` mirrors the parent's ``done``.

    Note: ``Task.__init__`` is not called here, so attributes such as
    ``func``, ``args`` and ``dependencies`` do not exist on ``SubTask``
    instances. Both properties are read-only.

    Args:
        task: The parent task whose result is indexed.
        index: Position within the parent's result.
    """

    def __init__(self, task: Task, index: int):
        self.index = index
        self.task = task

    @property
    def done(self):
        parent = self.task
        return parent.done

    @property
    def result(self):
        outputs = self.task.result
        return outputs[self.index]
50
51
[docs]
class Workflow(Algorithm):
    """Algorithm that executes a fixed graph of tasks.

    Built from the tasks collected by a ``WorkflowManager`` (see
    ``WorkflowManager.to_workflow``). Running it stores the incoming problem
    as the input task's result, executes every task, and returns the output
    task's result.

    Args:
        tasks: All tasks of the graph.
        input_task: Task whose ``result`` receives the problem passed to
            :meth:`run`, or ``None`` when the graph takes no input.
        output_task: Task whose ``result`` is returned from :meth:`run`, or
            ``None``, in which case :meth:`run` returns ``None``.
        input_format: Expected problem type; stored only, not checked here.
    """

    def __init__(
        self,
        tasks: list[Task],
        input_task: Task | None,
        output_task: Task | None,
        input_format: type[Problem | Model],
    ):
        self.tasks = tasks
        self.input_task = input_task
        self.output_task = output_task
        self.input_format = input_format

    def run(self, problem: Problem | Model, backend: Backend) -> Any:  # noqa: ANN401
        """Execute the graph on ``problem`` and return the output task's result.

        Tasks are run on a fresh thread pool. ``backend`` is not used by this
        implementation.
        """
        # WorkflowManager.to_workflow may hand us None for either end; mirror
        # the guards in WorkflowManager.__call__ instead of raising AttributeError.
        if self.input_task is not None:
            self.input_task.result = problem
        with concurrent.futures.ThreadPoolExecutor() as executor:
            _execute_workflow(self.tasks, executor)
        if self.output_task is not None:
            return self.output_task.result
        return None
64
69
70
[docs]
class WorkflowManager:
    """Collects tasks into a dependency graph and executes it.

    Can be used as a context manager (entering returns the manager itself,
    exiting does nothing and does not suppress exceptions) and called directly
    to run every registered task.

    Args:
        manager: Name of the execution backend; stored only, not consulted by
            this class.
    """

    def __init__(self, manager: Literal['ql', 'prefect', 'airflow'] = 'ql'):
        self.tasks: list[Task] = []
        self.manager = manager
        self.input_task: Task | None = None
        self.input_task_format: type[Problem | Model] = Model
        self.output_task: Task | None = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass

    def task(self, func, args: tuple | None = None, kwargs: dict | None = None, num_output: int | None = 1) -> Task:
        """Register ``func`` as a new task and return it.

        Args:
            func: Callable to execute.
            args: Positional arguments; ``Task`` instances become dependencies.
            kwargs: Keyword arguments; ``Task`` values become dependencies.
            num_output: How many ``SubTask`` views the returned task yields when
                unpacked. ``None`` means the default of 1.

        Returns:
            The newly registered task.
        """
        args = args or tuple()
        kwargs = kwargs or dict()
        # Task.__iter__ passes num_output to range(); None would break unpacking.
        if num_output is None:
            num_output = 1
        new_task = Task(func, args, kwargs, num_output=num_output)
        self.tasks.append(new_task)
        return new_task

    def __call__(self, input_value=None, /):
        """Run all registered tasks and return the output task's result.

        ``input_value`` is stored as the input task's result first, if an input
        task is set. Returns ``None`` when no output task is set.
        """
        if self.input_task is not None:
            self.input_task.result = input_value
        with concurrent.futures.ThreadPoolExecutor() as executor:
            _execute_workflow(self.tasks, executor)
        if self.output_task is not None:
            return self.output_task.result
        return None

    @staticmethod
    def _label(task: Task) -> str:
        """Return a printable name for a task or a ``SubTask`` view."""
        # SubTask never runs Task.__init__, so it has no ``func``; describe it
        # through its parent as ``parent[index]``.
        if isinstance(task, SubTask):
            return f'{WorkflowManager._label(task.task)}[{task.index}]'
        # Callables such as functools.partial objects have no __name__.
        return getattr(task.func, '__name__', repr(task.func))

    def print_dag(self) -> None:
        """Print one line per task: its name, then the names of its dependencies."""
        for task in self.tasks:
            dep_names = [self._label(dep) for dep in task.dependencies]
            print(f'{self._label(task)} -> {dep_names}')

    def output(self, task: Task) -> None:
        """Mark ``task`` as the one whose result is returned by a run."""
        self.output_task = task

    def to_workflow(self) -> Workflow:
        """Package the registered tasks and input/output settings as a ``Workflow``."""
        return Workflow(self.tasks, self.input_task, self.output_task, input_format=self.input_task_format)
117
118
def _execute_workflow(tasks: list[Task], executor: concurrent.futures.Executor, max_iterations: int | None = None) -> None:
    """Run ``tasks`` in dependency order, one wave at a time.

    Each iteration submits every remaining task whose dependencies are
    satisfied to ``executor``, then waits for that whole wave to finish before
    computing the next one.

    A dependency counts as satisfied only when it reports ``done`` *and* the
    task that produces it is no longer pending in this call. The second
    condition matters when the same task list is executed again: tasks keep
    ``done=True`` from the previous run, so checking ``done`` alone would
    launch every task at once against stale dependency results.

    Args:
        tasks: Tasks to run. Dependencies that are not in this list must
            already be done, otherwise their dependents can never be scheduled.
        executor: Executor that each task's ``run`` is submitted to.
        max_iterations: Upper bound on the number of waves; defaults to
            ``len(tasks)`` (also when ``0`` is passed).

    Raises:
        RuntimeError: if no remaining task can be scheduled (dependency cycle or
            a dependency that is never done), or if tasks are still pending
            after ``max_iterations`` waves.
        Exception: an exception raised by a task's ``run`` is propagated.
    """
    remaining_tasks = set(tasks)
    max_iterations = max_iterations or len(remaining_tasks)

    def producer(dep: Task) -> Task:
        # A SubTask is only a view on its parent's result; the parent is what runs.
        return dep.task if isinstance(dep, SubTask) else dep

    def is_ready(task: Task) -> bool:
        return task.is_ready() and all(producer(dep) not in remaining_tasks for dep in task.dependencies)

    for _ in range(max_iterations):
        if not remaining_tasks:
            return
        ready_tasks = [task for task in remaining_tasks if is_ready(task)]
        if not ready_tasks:
            raise RuntimeError('Cycle or error in tasks.')

        futures = [executor.submit(task.run) for task in ready_tasks]
        for future in concurrent.futures.as_completed(futures):
            # result() re-raises any exception raised inside task.run.
            future.result()

        remaining_tasks.difference_update(ready_tasks)

    # Previously this check was unreachable, so hitting the limit returned
    # silently with tasks left unexecuted.
    if remaining_tasks:
        raise RuntimeError('Processing take too much iterations')