1import contextlib
2import os
3import pickle
4import shutil
5import subprocess
6import sys
7import time
8from collections.abc import Callable
9from pathlib import Path
10from typing import TYPE_CHECKING, Any
11
12from qlauncher.exceptions import DependencyError
13from qlauncher.workflow.base_job_manager import BaseJobManager
14
15if TYPE_CHECKING:
16 from qlauncher.base import Result
17
18try:
19 import dill
20except ImportError as e:
21 raise DependencyError(e, install_hint='pilotjob') from e
22
23
class SlurmJobManager(BaseJobManager):
    """Job manager that runs serialized Python callables as Slurm batch jobs."""

    def __init__(
        self,
        sbatch_exe: str = 'sbatch',
        scancel_exe: str = 'scancel',
        slurm_options: dict[str, Any] | None = None,
        env_setup: list[str] | None = None,
    ) -> None:
        """
        Job manager that submits QLauncher jobs to Slurm via ``sbatch``.

        Args:
            sbatch_exe (str, optional): Name or path of the ``sbatch`` executable
                used to submit jobs to Slurm. Defaults to ``"sbatch"``.
            scancel_exe (str, optional): Name or path of the ``scancel`` executable
                used to cancel submitted jobs. Defaults to ``"scancel"``.
            slurm_options (dict[str, Any] | None, optional): Mapping of Slurm
                options to their values (e.g. ``{"time": "00:02:00"}``).
                Keys are used as option names after ``--`` in the generated
                ``#SBATCH`` lines. Defaults to an empty dict.
            env_setup (list[str] | None, optional): List of shell commands that
                will be written into the Slurm script before the ``srun`` line,
                e.g. module loads or virtual environment activation commands.
                Defaults to an empty list.

        Raises:
            DependencyError: If the ``sbatch_exe`` or ``scancel_exe`` executable
                cannot be found in ``PATH``.
        """
        super().__init__()
        # Worker entry point executed via srun; expected to live next to this module.
        self.code_path = Path(__file__).with_name('subprocess_fn.py')
        self.sbatch_exe = sbatch_exe
        self.scancel_exe = scancel_exe
        self.slurm_options = slurm_options or {}
        self.env_setup = env_setup or []

        # Fail fast if the Slurm client tools are unavailable (checked in the
        # same order as before: sbatch first, then scancel).
        for exe in (self.sbatch_exe, self.scancel_exe):
            if shutil.which(exe) is None:
                raise DependencyError(
                    ImportError(f'{exe} not found in PATH'),
                    install_hint='slurm',
                )

    def submit(
        self,
        function,
        cores: int = 1,
        **kwargs,
    ) -> str:
        """
        Serializes ``function`` with dill and submits it to Slurm via ``sbatch``.

        A per-job batch script is generated that runs the worker entry point
        (``subprocess_fn.py``) through ``srun``; the worker is expected to load
        the pickled callable from the input file and write its result to the
        output file.

        Args:
            function: Callable to execute on the cluster. It is serialized
                with ``dill`` into a per-job input file.
            cores (int, optional): Number of CPU cores per task requested from
                Slurm (mapped to ``--cpus-per-task``). Defaults to 1.
            **kwargs: Accepted for interface compatibility with
                :meth:`BaseJobManager.run`; currently ignored.

        Returns:
            str: Slurm job ID returned by ``sbatch``.

        Raises:
            RuntimeError: If ``sbatch`` returns a non-zero exit code.
        """
        job_uid = self._make_job_uid()

        input_file = f'input.{job_uid}.pkl'
        output_file = f'output.{job_uid}.pkl'
        script_path = f'slurm_job.{job_uid}.sh'

        with open(input_file, 'wb') as f:
            dill.dump(function, f)

        self._write_sbatch_script(
            script_path=script_path,
            job_uid=job_uid,
            input_file=input_file,
            output_file=output_file,
            cores=cores,
        )

        res = subprocess.run(
            [self.sbatch_exe, script_path],
            capture_output=True,
            text=True,
            check=False,
        )

        if res.returncode != 0:
            raise RuntimeError(f'sbatch failed ({res.returncode}): {res.stderr}')

        # sbatch prints e.g. "Submitted batch job 12345"; the ID is the last token.
        job_id = res.stdout.strip().split()[-1]

        self.jobs[job_id] = {
            'uid': job_uid,
            'input_file': input_file,
            'output_file': output_file,
            'script_path': script_path,
            'finished': False,
        }
        return job_id

    def wait_for_a_job(
        self,
        job_id: str | None = None,
        timeout: float | None = None,
    ):
        """
        Waits until a Slurm job finishes and returns its ID.

        Args:
            job_id (str | None, optional): ID of the job to wait for. If
                ``None``, the first job in :attr:`jobs` that is not yet marked
                as finished is selected. Defaults to ``None``.
            timeout (float | None, optional): Maximum time to wait in seconds.
                If ``None``, wait indefinitely. Defaults to ``None``.

        Raises:
            ValueError: If ``job_id`` is ``None`` and there are no jobs left.
            TimeoutError: If the timeout is exceeded before the job finishes.
            RuntimeError: If the job disappears from ``squeue`` without
                producing a result file, or if it finishes in a non-successful
                state.

        Returns:
            str: ID of the finished job.
        """
        if job_id is None:
            not_finished = [jid for jid, j in self.jobs.items() if not j['finished']]
            if not not_finished:
                raise ValueError('There are no jobs left')
            job_id = not_finished[0]

        job = self.jobs[job_id]
        # A job cancelled via cancel() needs no polling; report it immediately.
        if job.get('canceled', False):
            job['finished'] = True
            return job_id

        output_file = job['output_file']

        start = time.time()

        while True:
            now = time.time()
            if timeout is not None and (now - start) > timeout:
                raise TimeoutError(f'Timeout waiting for job {job_id}')

            state = self._get_slurm_state(job_id)

            # No state: the job left the queue (or squeue is unavailable);
            # the output file is the only remaining evidence of success.
            if state is None:
                if Path(output_file).exists():
                    job['finished'] = True
                    return job_id
                raise RuntimeError(f'Job {job_id} disappeared from squeue but result file does not exist: {output_file}')

            if state in ('PENDING', 'CONFIGURING', 'RUNNING', 'COMPLETING'):
                time.sleep(2.0)
                continue

            if state in ('COMPLETED', 'CG'):
                if not Path(output_file).exists():
                    raise RuntimeError(f'Job {job_id} finished with state {state}, but result file not found: {output_file}')
                job['finished'] = True
                return job_id

            if state in ('CANCELLED', 'CANCELED'):
                job['canceled'] = True
                job['finished'] = True
                return job_id

            raise RuntimeError(f'Job {job_id} finished in bad state: {state}')

    def read_results(self, job_id):
        """
        Reads the result of a finished job from its output file.

        Args:
            job_id (str): Slurm job ID returned by :meth:`submit`.

        Raises:
            KeyError: If ``job_id`` is not known to this manager.
            RuntimeError: If the job was canceled, so no results exist.
            FileNotFoundError: If the expected output file does not exist.

        Returns:
            Result: Deserialized result object produced by the worker process.
        """
        if job_id not in self.jobs:
            raise KeyError(f'Job {job_id} not found')

        job = self.jobs[job_id]
        if job.get('canceled', False):
            raise RuntimeError(f'Job {job_id} was canceled; no results are available')

        output_file = job['output_file']

        if not Path(output_file).exists():
            raise FileNotFoundError(f'Result file for job {job_id} not found: {output_file}')

        # The worker writes the result with the standard pickle protocol.
        with open(output_file, 'rb') as rt:
            result: Result = pickle.load(rt)

        job['finished'] = True
        return result

    def cancel(self, job_id: str) -> None:
        """
        Cancel a given Slurm job via scancel.

        Args:
            job_id: Slurm job id returned by submit().

        Raises:
            KeyError: If job_id is not known to this manager.
            RuntimeError: If scancel fails.
        """
        if job_id not in self.jobs:
            raise KeyError(f'Job {job_id} not found')

        res = subprocess.run(
            [self.scancel_exe, job_id],
            capture_output=True,
            text=True,
            check=False,
        )

        if res.returncode != 0:
            raise RuntimeError(f'scancel failed ({res.returncode}): {res.stderr.strip() or res.stdout.strip()}')

        job = self.jobs[job_id]
        job['canceled'] = True
        job['finished'] = True

    def clean_up(self):
        """
        Removes temporary files created for all tracked jobs.

        Only the generated sbatch scripts and pickled input files are removed;
        output files are intentionally kept so results remain readable.
        """
        for job in self.jobs.values():
            for key in ('script_path', 'input_file'):
                path = job.get(key)
                if path and Path(path).exists():
                    with contextlib.suppress(OSError):
                        os.remove(path)

    def run(self, function: Callable[..., Any], cores: int = 1, **kwargs) -> Any:
        """Submit ``function`` and block until its result is available (see base class)."""
        return super().run(function, cores=cores, **kwargs)

    def _write_sbatch_script(
        self,
        script_path: str,
        job_uid: str,
        input_file: str,
        output_file: str,
        cores: int,
    ) -> None:
        """Generate the per-job sbatch script that runs the worker via ``srun``.

        User-supplied ``slurm_options`` win over the defaults (``ntasks=1``,
        ``cpus-per-task=cores``) because they are only filled in via
        ``setdefault``.
        """
        opts = dict(self.slurm_options)
        opts.setdefault('ntasks', 1)
        opts.setdefault('cpus-per-task', cores)

        stdout_file = f'stdout.{job_uid}'
        stderr_file = f'stderr.{job_uid}'

        with open(script_path, 'w', encoding='utf-8') as sh:
            sh.write('#!/bin/bash\n')
            sh.write(f'#SBATCH --job-name=ql_{job_uid}\n')
            sh.write(f'#SBATCH --output={stdout_file}\n')
            sh.write(f'#SBATCH --error={stderr_file}\n')

            for opt, val in opts.items():
                sh.write(f'#SBATCH --{opt}={val}\n')

            # Environment preparation (module loads, venv activation, ...).
            for line in self.env_setup:
                sh.write(line + '\n')

            # Run the worker with the same interpreter that created the pickle.
            sh.write(f'srun {sys.executable} {self.code_path} {input_file} {output_file}\n')

    @staticmethod
    def _get_slurm_state(job_id: str) -> str | None:
        """Return the job's squeue state (e.g. ``RUNNING``), or ``None``.

        ``None`` means the state could not be determined: squeue is missing,
        exited non-zero, or no longer lists the job (it left the queue).
        """
        try:
            res = subprocess.run(
                ['squeue', '-h', '-j', job_id, '-o', '%T'],
                capture_output=True,
                text=True,
                check=False,
            )
        except FileNotFoundError:
            return None

        if res.returncode != 0:
            return None

        out = res.stdout.strip()
        if not out:
            return None

        # Array/het jobs may yield several lines; the first state is enough here.
        return out.splitlines()[0].strip()