1import concurrent.futures
2from typing import Any, Literal
3from collections.abc import Callable
4from qlauncher.base import Algorithm
5
6
[docs]
7class Task:
8 def __init__(self, func: Callable, args: tuple[Any] | None = None, kwargs: dict[str, Any] | None = None, num_output: int = 1):
9 if args is None:
10 args = tuple()
11 if kwargs is None:
12 kwargs = {}
13 self.func = func
14 self.dependencies: list[Task] = [arg for arg in args if isinstance(arg, Task)]
15 self.dependencies.extend([value for value in kwargs.values() if isinstance(value, Task)])
16 self.args = args
17 self.kwargs = kwargs
18 self.done = False
19 self.result = None
20 self.num_output = num_output
21 self.subtasks: list[SubTask] = []
22
[docs]
23 def run(self):
24 binded_args = [arg.result if isinstance(arg, Task) else arg for arg in self.args]
25 binded_kwargs = {key: (value.result if isinstance(value, Task) else value) for key, value in self.kwargs.items()}
26 self.result = self.func(*binded_args, **binded_kwargs)
27 self.done = True
28
[docs]
29 def is_ready(self):
30 return all(map(lambda x: x.done, self.dependencies))
31
32 def __iter__(self):
33 for i in range(self.num_output):
34 yield SubTask(self, i)
35
36
[docs]
class SubTask(Task):
    """Handle to a single element of a multi-output ``Task``'s result.

    Instances are produced by iterating a ``Task`` (``first, second = task``).
    Because this class subclasses ``Task``, passing an instance as an argument
    to another ``Task`` records it as a dependency. That dependency is done
    when the parent is done, and it resolves to ``parent.result[index]``.

    Only ``task``, ``index``, ``done`` and ``result`` are available on a
    ``SubTask``. The other ``Task`` attributes (``func``, ``dependencies``,
    ``args``, ...) are never set.

    Args:
        task: The parent task whose result is indexed.
        index: Position of this output within the parent's result.
    """

    def __init__(self, task: Task, index: int):
        # Task.__init__ is not called: it assigns ``done`` and ``result``,
        # which are read-only properties on this class.
        self.task, self.index = task, index

    @property
    def done(self):
        """Whether the parent task has finished running."""
        return self.task.done

    @property
    def result(self):
        """The parent's result at ``index`` (the parent's result must support indexing)."""
        parent_result = self.task.result
        return parent_result[self.index]
49
50
[docs]
class Workflow(Algorithm):
    """``Algorithm`` adapter that executes a prebuilt task graph.

    Args:
        tasks: All tasks to execute.
        input_task: Task whose ``result`` receives the formatted problem, or
            None for a graph without an input.
        output_task: Task whose ``result`` is returned by :meth:`run`, or None.
        input_format: Stored as ``_algorithm_format`` on the instance,
            overriding the class default ``'none'``. How that value is
            consumed is defined outside this class.
    """

    _algorithm_format = 'none'

    def __init__(self, tasks: list[Task], input_task: Task | None, output_task: Task | None, input_format: str = 'none'):
        # NOTE(review): Algorithm.__init__ is not called; confirm the base class needs no initialisation.
        self.tasks = tasks
        self.input_task = input_task
        self.output_task = output_task
        self._algorithm_format = input_format

    def run(self, problem, backend, formatter):
        """Format ``problem``, feed it to the input task and execute every task.

        Args:
            problem: Value passed to ``formatter``.
            backend: Not used by this method.
            formatter: Callable applied to ``problem``. Its return value is
                assigned to the input task's ``result`` (when there is an
                input task).

        Returns:
            The output task's ``result``, or None when no output task is set.
        """
        input_result = formatter(problem)
        # Guard both ends like WorkflowManager.__call__ does; to_workflow may pass None for either.
        if self.input_task is not None:
            # NOTE(review): only ``result`` is assigned here, not ``done``; whether
            # dependents become ready depends on how the input task was created.
            self.input_task.result = input_result
        with concurrent.futures.ThreadPoolExecutor() as executor:
            _execute_workflow(self.tasks, executor)
        if self.output_task is not None:
            return self.output_task.result
        return None
66
67
[docs]
class WorkflowManager:
    """Collects tasks into a dependency graph and runs them.

    Usable as a context manager: entering returns the manager itself, and
    exiting does nothing (exceptions are not suppressed).

    Args:
        manager: Name of an execution backend, stored as ``self.manager``.
            The methods of this class do not consult it.
    """

    def __init__(self, manager: Literal['ql', 'prefect', 'airflow'] = 'ql'):
        self.tasks: list[Task] = []
        self.manager = manager
        self.input_task: Task | None = None
        self.input_task_format: str = 'none'
        self.output_task: Task | None = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning None lets any exception from the ``with`` body propagate.
        return None

    def task(self, func, args: tuple | None = None, kwargs: dict | None = None, num_output: int | None = None) -> Task:
        """Create a ``Task`` for ``func``, register it with this manager and return it.

        Args:
            func: Callable the task runs.
            args: Positional arguments; ``Task`` entries become dependencies.
            kwargs: Keyword arguments; ``Task`` values become dependencies.
            num_output: Number of outputs to unpack from the task
                (``a, b = manager.task(f, num_output=2)``). Leave as None for a
                task that will not be unpacked.
        """
        new_task = Task(func, args or (), kwargs or {}, num_output=num_output)
        self.tasks.append(new_task)
        return new_task

    def __call__(self, input_value=None, /):
        """Run every registered task.

        Args:
            input_value: Assigned to the input task's ``result`` when an input
                task is set; ignored otherwise.

        Returns:
            The output task's ``result``, or None when no output task is set.
        """
        if self.input_task is not None:
            self.input_task.result = input_value
        with concurrent.futures.ThreadPoolExecutor() as executor:
            _execute_workflow(self.tasks, executor)
        if self.output_task is not None:
            return self.output_task.result
        return None

    def print_dag(self):
        """Print one line per task: its function name followed by its dependencies' names.

        A dependency on one output of a multi-output task is shown as ``name[index]``.
        """
        for task in self.tasks:
            # SubTask dependencies have no ``func`` of their own; label them via their parent task.
            dep_names = [
                f"{dep.task.func.__name__}[{dep.index}]" if isinstance(dep, SubTask) else dep.func.__name__
                for dep in task.dependencies
            ]
            print(f"{task.func.__name__} -> {dep_names}")

    def output(self, task: Task):
        """Mark ``task`` as the one whose result the workflow returns."""
        self.output_task = task

    def to_workflow(self) -> Workflow:
        """Package the registered tasks, input/output tasks and input format as a ``Workflow``."""
        return Workflow(self.tasks, self.input_task, self.output_task, input_format=self.input_task_format)
113
114
def _execute_workflow(tasks: list[Task], executor: concurrent.futures.Executor, max_iterations: int | None = None):
    """Run ``tasks`` in dependency order, one wave of ready tasks at a time.

    Each round submits every remaining task whose dependencies are done to
    ``executor`` and waits for the whole wave to finish before the next round.
    A dependency that is not in ``tasks`` is never run here, so it must
    already be done.

    Args:
        tasks: Tasks to execute; duplicates are executed once.
        executor: Executor the ready tasks' ``run`` methods are submitted to.
        max_iterations: Maximum number of rounds. A falsy value (the default)
            means ``len(set(tasks))``, which always suffices because every
            round completes at least one task.

    Raises:
        RuntimeError: if tasks remain but none is ready (a dependency cycle,
            or a dependency that is neither in ``tasks`` nor done), or if tasks
            are still pending after ``max_iterations`` rounds.
        Exception: the exception raised by the first task in a wave that
            fails, in completion order.
    """
    remaining_tasks = set(tasks)
    iteration_limit = max_iterations or len(remaining_tasks)
    for _ in range(iteration_limit):
        if not remaining_tasks:
            return
        ready_tasks = [task for task in remaining_tasks if task.is_ready()]
        if not ready_tasks:
            raise RuntimeError("Cycle or error in tasks.")

        futures = [executor.submit(task.run) for task in ready_tasks]
        for future in concurrent.futures.as_completed(futures):
            # Future.result() re-raises the exception raised inside task.run, if any.
            future.result()

        remaining_tasks.difference_update(ready_tasks)

    # The old in-loop counter check could never fire; a too-small explicit
    # max_iterations used to return silently with tasks left unrun.
    if remaining_tasks:
        raise RuntimeError("Processing take too much iterations")