1import concurrent.futures
2from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
3from quantum_launcher.base import Algorithm
4
5
[docs]
class Task:
    """A callable together with the arguments it will be invoked with.

    Any positional or keyword argument that is itself a ``Task`` is recorded
    as a dependency. When the task runs, each such argument is replaced by
    that task's ``result``.
    """

    def __init__(
        self,
        func: Callable,
        args: Optional[Tuple[Any, ...]] = None,
        kwargs: Optional[Dict[str, Any]] = None,
        num_output: Optional[int] = 1,
    ):
        """Store the callable and its arguments and collect task dependencies.

        Args:
            func: Callable executed by :meth:`run`.
            args: Positional arguments for ``func``; ``None`` means no
                positional arguments.
            kwargs: Keyword arguments for ``func``; ``None`` means no keyword
                arguments.
            num_output: How many ``SubTask`` views iterating over this task
                yields. ``None`` is treated as ``1``.
        """
        if args is None:
            args = tuple()
        if kwargs is None:
            kwargs = {}
        self.func = func
        # Only top-level Task arguments count as dependencies; tasks nested
        # inside containers are not inspected.
        self.dependencies: List[Task] = [arg for arg in args if isinstance(arg, Task)]
        self.dependencies.extend(value for value in kwargs.values() if isinstance(value, Task))
        self.args = args
        self.kwargs = kwargs
        self.done = False
        self.result = None
        # Callers may forward ``None`` for "not specified"; fall back to a
        # single output so that ``__iter__`` can always build a range.
        self.num_output = 1 if num_output is None else num_output
        self.subtasks: List["SubTask"] = []

    def run(self):
        """Call ``func`` with dependencies resolved, store the result, mark done."""
        bound_args = [arg.result if isinstance(arg, Task) else arg for arg in self.args]
        bound_kwargs = {
            key: (value.result if isinstance(value, Task) else value)
            for key, value in self.kwargs.items()
        }
        self.result = self.func(*bound_args, **bound_kwargs)
        self.done = True

    def is_ready(self):
        """Return ``True`` when every dependency reports ``done``."""
        return all(dependency.done for dependency in self.dependencies)

    def __iter__(self):
        """Yield one ``SubTask`` per output, enabling ``a, b = task`` unpacking."""
        for index in range(self.num_output):
            yield SubTask(self, index)
34
35
[docs]
class SubTask(Task):
    """View onto one element of a parent task's result.

    Instances are produced by iterating over a ``Task``. Because ``SubTask``
    is a ``Task`` subclass, it can be passed as an argument to another task
    and is resolved through its ``result`` like any other dependency.
    """

    def __init__(self, task: Task, index: int):
        """Bind the view to ``task`` and the position ``index`` in its result.

        ``Task.__init__`` is intentionally not called: it assigns ``result``
        and ``done``, which are read-only properties on this class.
        """
        self.task = task
        self.index = index
        # A view depends only on its parent; exposing that keeps the inherited
        # ``is_ready`` and any code walking ``dependencies`` working.
        self.dependencies = [task]

    @property
    def func(self):
        """Callable of the parent task, so code reading ``func`` works on views."""
        return self.task.func

    @property
    def result(self):
        """Element ``index`` of the parent task's result."""
        return self.task.result[self.index]

    @property
    def done(self):
        """Completion flag of the parent task."""
        return self.task.done
48
49
[docs]
class Workflow(Algorithm):
    """``Algorithm`` subclass that executes a fixed list of tasks.

    The formatted problem is injected as the result of ``input_task`` and the
    value returned is the result of ``output_task``.
    """

    _algorithm_format = 'none'

    def __init__(
        self,
        tasks: List[Task],
        input_task: Optional[Task],
        output_task: Optional[Task],
        input_format: str = 'none',
    ):
        """Store the task graph and its entry and exit points.

        Args:
            tasks: Tasks to execute.
            input_task: Task whose ``result`` receives the formatted problem,
                or ``None`` when the workflow takes no input.
            output_task: Task whose ``result`` is returned by :meth:`run`, or
                ``None`` when the workflow produces no value.
            input_format: Stored as this instance's ``_algorithm_format``,
                overriding the class-level default.
        """
        self.tasks = tasks
        self.input_task = input_task
        self.output_task = output_task
        self._algorithm_format = input_format

    def run(self, problem, backend, formatter):
        """Format ``problem``, execute all tasks, and return the output result.

        Args:
            problem: Object handed to ``formatter``.
            backend: Accepted for signature compatibility; not used here.
            formatter: Callable applied to ``problem``; its return value
                becomes the input task's result.

        Returns:
            ``output_task.result``, or ``None`` when no output task is set.
        """
        input_result = formatter(problem)
        # Either endpoint may legitimately be None (both are optional on the
        # manager that builds workflows), so guard before touching them.
        if self.input_task is not None:
            self.input_task.result = input_result
        with concurrent.futures.ThreadPoolExecutor() as executor:
            _execute_workflow(self.tasks, executor)
        if self.output_task is None:
            return None
        return self.output_task.result
65
66
[docs]
class WorkflowManager:
    """Collects tasks into a workflow and can execute or export it.

    Usable as a context manager; entering returns the manager itself and
    exiting does nothing, so exceptions raised in the body propagate.
    """

    def __init__(self, manager: Literal['ql', 'prefect', 'airflow'] = 'ql'):
        """Create an empty manager.

        Args:
            manager: Name of the execution backend. It is stored on the
                instance; the code in this class does not branch on it.
        """
        self.tasks: List[Task] = []
        self.manager = manager
        # Task that receives the value passed to ``__call__``; None until set.
        self.input_task: Optional[Task] = None
        self.input_task_format: str = 'none'
        # Task whose result ``__call__`` returns; None until ``output`` is called.
        self.output_task: Optional[Task] = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass

    def task(
        self,
        func,
        args: Optional[Tuple] = None,
        kwargs: Optional[Dict] = None,
        num_output: Optional[int] = None,
    ) -> "Task":
        """Create a task, register it with this manager, and return it.

        Args:
            func: Callable the task will execute.
            args: Positional arguments; falsy values mean "none".
            kwargs: Keyword arguments; falsy values mean "none".
            num_output: Number of outputs the task can be unpacked into.
                ``None`` means a single output.
        """
        args = args or tuple()
        kwargs = kwargs or dict()
        # Never forward None: iterating a task builds ``range(num_output)``.
        outputs = 1 if num_output is None else num_output
        new_task = Task(func, args, kwargs, num_output=outputs)
        self.tasks.append(new_task)
        return new_task

    def __call__(self, input_value=None, /):
        """Execute all registered tasks and return the output task's result.

        ``input_value`` is stored as the result of the input task when one is
        set. Returns ``None`` when no output task has been registered.
        """
        if self.input_task is not None:
            self.input_task.result = input_value
        with concurrent.futures.ThreadPoolExecutor() as executor:
            _execute_workflow(self.tasks, executor)
        if self.output_task is not None:
            return self.output_task.result
        return None

    @staticmethod
    def _display_name(node) -> str:
        """Return a printable name for a task or a sub-task view."""
        # Sub-task views keep their parent in ``task``; use the parent's callable.
        owner = getattr(node, "task", node)
        func = owner.func
        # Callables such as functools.partial objects have no ``__name__``.
        return getattr(func, "__name__", repr(func))

    def print_dag(self):
        """Print one line per task: its callable's name and its dependencies' names."""
        for task in self.tasks:
            dep_names = [self._display_name(dep) for dep in task.dependencies]
            print(f"{self._display_name(task)} -> {dep_names}")

    def output(self, task: "Task"):
        """Mark ``task`` as the one whose result the workflow returns."""
        self.output_task = task

    def to_workflow(self) -> "Workflow":
        """Build a ``Workflow`` from the manager's current tasks and endpoints."""
        return Workflow(
            self.tasks,
            self.input_task,
            self.output_task,
            input_format=self.input_task_format,
        )
112
113
def _execute_workflow(
    tasks: List["Task"],
    executor: concurrent.futures.Executor,
    max_iterations: Optional[int] = None,
):
    """Run ``tasks`` in dependency order, one wave of ready tasks at a time.

    Each iteration submits every task whose dependencies are done to
    ``executor``, waits for the whole wave, and then removes it from the
    pending set.

    Args:
        tasks: Tasks to execute; duplicates are collapsed.
        executor: Executor the tasks' ``run`` methods are submitted to.
        max_iterations: Upper bound on the number of waves. ``None`` (or
            ``0``) means the number of distinct tasks, which is always enough
            for an acyclic graph.

    Raises:
        RuntimeError: If no pending task is ready (a cycle, or a dependency
            that never completes), or if tasks are still pending after
            ``max_iterations`` waves.
        Exception: Any exception raised by a task's ``run`` is re-raised.
    """
    remaining_tasks = set(tasks)
    # Every task in this set is executed below, so completion flags left over
    # from a previous execution are stale. Without this reset all tasks would
    # look ready at once and dependents could read outdated results.
    for task in remaining_tasks:
        task.done = False

    if not max_iterations:
        max_iterations = len(remaining_tasks)

    for _ in range(max_iterations):
        if not remaining_tasks:
            return

        ready_tasks = [task for task in remaining_tasks if task.is_ready()]
        if not ready_tasks:
            raise RuntimeError("Cycle or error in tasks.")

        futures = [executor.submit(task.run) for task in ready_tasks]
        for future in concurrent.futures.as_completed(futures):
            # ``result()`` re-raises whatever the task raised.
            future.result()

        remaining_tasks.difference_update(ready_tasks)

    # The wave budget is exhausted; leftover tasks must not be dropped silently.
    if remaining_tasks:
        raise RuntimeError("Processing take too much iterations")