Coverage for src/srunx/client.py: 87%
153 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-24 15:16 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-24 15:16 +0000
1"""SLURM client for job submission and management."""
3import subprocess
4import tempfile
5import time
6from collections.abc import Sequence
7from importlib.resources import files
8from pathlib import Path
10from srunx.callbacks import Callback
11from srunx.logging import get_logger
12from srunx.models import (
13 BaseJob,
14 Job,
15 JobStatus,
16 JobType,
17 RunnableJobType,
18 ShellJob,
19 render_job_script,
20)
21from srunx.utils import get_job_status, job_status_msg
# Module-level logger, obtained from srunx.logging.get_logger and named after this module.
logger = get_logger(__name__)
26class Slurm:
27 """Client for interacting with SLURM workload manager."""
29 def __init__(
30 self,
31 default_template: str | None = None,
32 callbacks: Sequence[Callback] | None = None,
33 ):
34 """Initialize SLURM client.
36 Args:
37 default_template: Path to default job template.
38 callbacks: List of callbacks.
39 """
40 self.default_template = default_template or self._get_default_template()
41 self.callbacks = list(callbacks) if callbacks else []
43 def submit(
44 self,
45 job: RunnableJobType,
46 template_path: str | None = None,
47 callbacks: Sequence[Callback] | None = None,
48 verbose: bool = False,
49 ) -> RunnableJobType:
50 """Submit a job to SLURM.
52 Args:
53 job: Job configuration.
54 template_path: Optional template path (uses default if not provided).
55 callbacks: List of callbacks.
56 verbose: Whether to print the rendered content.
58 Returns:
59 Job instance with updated job_id and status.
61 Raises:
62 subprocess.CalledProcessError: If job submission fails.
63 """
65 if isinstance(job, Job):
66 template = template_path or self.default_template
68 with tempfile.TemporaryDirectory() as temp_dir:
69 script_path = render_job_script(template, job, temp_dir, verbose)
70 logger.debug(f"Generated SLURM script at: {script_path}")
72 # Submit job with sbatch
73 sbatch_cmd = ["sbatch", script_path]
74 if job.environment.sqsh:
75 logger.debug(f"Using sqsh container: {job.environment.sqsh}")
77 logger.debug(f"Executing command: {' '.join(sbatch_cmd)}")
79 try:
80 result = subprocess.run(
81 sbatch_cmd,
82 capture_output=True,
83 text=True,
84 check=True,
85 )
86 except subprocess.CalledProcessError as e:
87 logger.error(f"Failed to submit job '{job.name}': {e}")
88 logger.error(f"Command: {' '.join(e.cmd)}")
89 logger.error(f"Return code: {e.returncode}")
90 logger.error(f"Stdout: {e.stdout}")
91 logger.error(f"Stderr: {e.stderr}")
92 raise
94 elif isinstance(job, ShellJob):
95 try:
96 result = subprocess.run(
97 ["sbatch", job.path],
98 capture_output=True,
99 text=True,
100 check=True,
101 )
102 except subprocess.CalledProcessError as e:
103 logger.error(f"Failed to submit job '{job.name}': {e}")
104 logger.error(f"Command: {' '.join(e.cmd)}")
105 logger.error(f"Return code: {e.returncode}")
106 logger.error(f"Stdout: {e.stdout}")
107 logger.error(f"Stderr: {e.stderr}")
108 raise
110 else:
111 raise ValueError("Either 'command' or 'file' must be set")
113 time.sleep(3)
114 job_id = int(result.stdout.split()[-1])
115 job.job_id = job_id
116 job.status = JobStatus.PENDING
118 logger.debug(f"Successfully submitted job '{job.name}' with ID {job_id}")
120 all_callbacks = self.callbacks[:]
121 if callbacks:
122 all_callbacks.extend(callbacks)
123 for callback in all_callbacks:
124 callback.on_job_submitted(job)
126 return job
128 @staticmethod
129 def retrieve(job_id: int) -> BaseJob:
130 """Retrieve job information from SLURM.
132 Args:
133 job_id: SLURM job ID.
135 Returns:
136 Job object with current status.
137 """
138 return get_job_status(job_id)
140 def cancel(self, job_id: int) -> None:
141 """Cancel a SLURM job.
143 Args:
144 job_id: SLURM job ID to cancel.
146 Raises:
147 subprocess.CalledProcessError: If job cancellation fails.
148 """
149 logger.info(f"Cancelling job {job_id}")
151 try:
152 subprocess.run(
153 ["scancel", str(job_id)],
154 check=True,
155 )
156 logger.info(f"Successfully cancelled job {job_id}")
157 except subprocess.CalledProcessError as e:
158 logger.error(f"Failed to cancel job {job_id}: {e}")
159 raise
161 def queue(self, user: str | None = None) -> list[BaseJob]:
162 """List jobs for a user.
164 Args:
165 user: Username (defaults to current user).
167 Returns:
168 List of Job objects.
169 """
170 cmd = [
171 "squeue",
172 "--format",
173 "%.18i %.9P %.15j %.8u %.8T %.10M %.9l %.6D %R",
174 "--noheader",
175 ]
176 if user:
177 cmd.extend(["--user", user])
179 result = subprocess.run(cmd, capture_output=True, text=True, check=True)
181 jobs = []
182 for line in result.stdout.strip().split("\n"):
183 if not line.strip():
184 continue
186 parts = line.split()
187 if len(parts) >= 5:
188 job_id = int(parts[0])
189 job_name = parts[2]
190 status_str = parts[4]
192 try:
193 status = JobStatus(status_str)
194 except ValueError:
195 status = JobStatus.PENDING # Default for unknown status
197 job = BaseJob(
198 name=job_name,
199 job_id=job_id,
200 )
201 job.status = status
202 jobs.append(job)
204 return jobs
206 def monitor(
207 self,
208 job_obj_or_id: JobType | int,
209 poll_interval: int = 5,
210 callbacks: Sequence[Callback] | None = None,
211 ) -> JobType:
212 """Wait for a job to complete.
214 Args:
215 job_obj_or_id: Job object or job ID.
216 poll_interval: Polling interval in seconds.
217 callbacks: List of callbacks.
219 Returns:
220 Completed job object.
222 Raises:
223 RuntimeError: If job fails.
224 """
225 if isinstance(job_obj_or_id, int):
226 job = self.retrieve(job_obj_or_id)
227 else:
228 job = job_obj_or_id
230 all_callbacks = self.callbacks[:]
231 if callbacks:
232 all_callbacks.extend(callbacks)
234 msg = f"👀 {'MONITORING':<12} Job {job.name:<12} (ID: {job.job_id})"
235 logger.info(msg)
237 previous_status = None
239 while True:
240 job.refresh()
242 # Log status changes
243 if job.status != previous_status:
244 status_str = job.status.value if job.status else "Unknown"
245 logger.debug(f"Job(name={job.name}, id={job.job_id}) is {status_str}")
246 previous_status = job.status
248 match job.status:
249 case JobStatus.COMPLETED:
250 logger.info(job_status_msg(job))
251 for callback in all_callbacks:
252 callback.on_job_completed(job)
253 return job
254 case JobStatus.FAILED:
255 err_msg = job_status_msg(job) + "\n"
256 if isinstance(job, Job):
257 log_file = Path(job.log_dir) / f"{job.name}_{job.job_id}.out"
258 if log_file.exists():
259 with open(log_file) as f:
260 err_msg += f.read()
261 err_msg += f"\nLog file: {log_file}"
262 else:
263 err_msg += f"Log file not found: {log_file}"
264 for callback in all_callbacks:
265 callback.on_job_failed(job)
266 raise RuntimeError(err_msg)
267 case JobStatus.CANCELLED | JobStatus.TIMEOUT:
268 err_msg = job_status_msg(job) + "\n"
269 if isinstance(job, Job):
270 log_file = Path(job.log_dir) / f"{job.name}_{job.job_id}.out"
271 if log_file.exists():
272 with open(log_file) as f:
273 err_msg += f.read()
274 err_msg += f"\nLog file: {log_file}"
275 else:
276 err_msg += f"Log file not found: {log_file}"
277 for callback in all_callbacks:
278 callback.on_job_cancelled(job)
279 raise RuntimeError(err_msg)
280 time.sleep(poll_interval)
282 def run(
283 self,
284 job: RunnableJobType,
285 template_path: str | None = None,
286 callbacks: Sequence[Callback] | None = None,
287 poll_interval: int = 5,
288 verbose: bool = False,
289 ) -> RunnableJobType:
290 """Submit a job and wait for completion."""
291 submitted_job = self.submit(
292 job, template_path=template_path, callbacks=callbacks, verbose=verbose
293 )
294 monitored_job = self.monitor(
295 submitted_job, poll_interval=poll_interval, callbacks=callbacks
296 )
298 # Ensure the return type matches the expected type
299 if isinstance(monitored_job, Job | ShellJob):
300 return monitored_job
301 else:
302 # This should not happen in practice, but needed for type safety
303 return submitted_job
305 def _get_default_template(self) -> str:
306 """Get the default job template path."""
307 return str(files("srunx.templates").joinpath("base.slurm.jinja"))
310# Convenience functions for backward compatibility
def submit_job(
    job: RunnableJobType,
    template_path: str | None = None,
    callbacks: Sequence[Callback] | None = None,
    verbose: bool = False,
) -> RunnableJobType:
    """Submit ``job`` through a freshly created ``Slurm`` client.

    Backward-compatible shortcut for ``Slurm().submit(...)``.

    Args:
        job: Job configuration.
        template_path: Optional template path (uses default if not provided).
        callbacks: List of callbacks.
        verbose: Whether to print the rendered content.

    Returns:
        The submitted job, with ``job_id`` and ``status`` filled in.
    """
    return Slurm().submit(
        job,
        template_path=template_path,
        callbacks=callbacks,
        verbose=verbose,
    )
def retrieve_job(job_id: int) -> BaseJob:
    """Look up the current state of a SLURM job (backward-compatible shortcut).

    Args:
        job_id: SLURM job ID.

    Returns:
        Job object describing the job's current status.
    """
    return Slurm().retrieve(job_id)
def cancel_job(job_id: int) -> None:
    """Cancel a SLURM job via a throwaway ``Slurm`` client (backward-compatible shortcut).

    Args:
        job_id: SLURM job ID.
    """
    Slurm().cancel(job_id)