Coverage for src/srunx/client.py: 87%

153 statements  

coverage.py v7.9.1, created at 2025-06-24 15:16 +0000

1"""SLURM client for job submission and management.""" 

2 

3import subprocess 

4import tempfile 

5import time 

6from collections.abc import Sequence 

7from importlib.resources import files 

8from pathlib import Path 

9 

10from srunx.callbacks import Callback 

11from srunx.logging import get_logger 

12from srunx.models import ( 

13 BaseJob, 

14 Job, 

15 JobStatus, 

16 JobType, 

17 RunnableJobType, 

18 ShellJob, 

19 render_job_script, 

20) 

21from srunx.utils import get_job_status, job_status_msg 

22 

23logger = get_logger(__name__) 

24 

25 

26class Slurm: 

27 """Client for interacting with SLURM workload manager.""" 

28 

29 def __init__( 

30 self, 

31 default_template: str | None = None, 

32 callbacks: Sequence[Callback] | None = None, 

33 ): 

34 """Initialize SLURM client. 

35 

36 Args: 

37 default_template: Path to default job template. 

38 callbacks: List of callbacks. 

39 """ 

40 self.default_template = default_template or self._get_default_template() 

41 self.callbacks = list(callbacks) if callbacks else [] 

42 

43 def submit( 

44 self, 

45 job: RunnableJobType, 

46 template_path: str | None = None, 

47 callbacks: Sequence[Callback] | None = None, 

48 verbose: bool = False, 

49 ) -> RunnableJobType: 

50 """Submit a job to SLURM. 

51 

52 Args: 

53 job: Job configuration. 

54 template_path: Optional template path (uses default if not provided). 

55 callbacks: List of callbacks. 

56 verbose: Whether to print the rendered content. 

57 

58 Returns: 

59 Job instance with updated job_id and status. 

60 

61 Raises: 

62 subprocess.CalledProcessError: If job submission fails. 

63 """ 

64 

65 if isinstance(job, Job): 

66 template = template_path or self.default_template 

67 

68 with tempfile.TemporaryDirectory() as temp_dir: 

69 script_path = render_job_script(template, job, temp_dir, verbose) 

70 logger.debug(f"Generated SLURM script at: {script_path}") 

71 

72 # Submit job with sbatch 

73 sbatch_cmd = ["sbatch", script_path] 

74 if job.environment.sqsh: 

75 logger.debug(f"Using sqsh container: {job.environment.sqsh}") 

76 

77 logger.debug(f"Executing command: {' '.join(sbatch_cmd)}") 

78 

79 try: 

80 result = subprocess.run( 

81 sbatch_cmd, 

82 capture_output=True, 

83 text=True, 

84 check=True, 

85 ) 

86 except subprocess.CalledProcessError as e: 

87 logger.error(f"Failed to submit job '{job.name}': {e}") 

88 logger.error(f"Command: {' '.join(e.cmd)}") 

89 logger.error(f"Return code: {e.returncode}") 

90 logger.error(f"Stdout: {e.stdout}") 

91 logger.error(f"Stderr: {e.stderr}") 

92 raise 

93 

94 elif isinstance(job, ShellJob): 

95 try: 

96 result = subprocess.run( 

97 ["sbatch", job.path], 

98 capture_output=True, 

99 text=True, 

100 check=True, 

101 ) 

102 except subprocess.CalledProcessError as e: 

103 logger.error(f"Failed to submit job '{job.name}': {e}") 

104 logger.error(f"Command: {' '.join(e.cmd)}") 

105 logger.error(f"Return code: {e.returncode}") 

106 logger.error(f"Stdout: {e.stdout}") 

107 logger.error(f"Stderr: {e.stderr}") 

108 raise 

109 

110 else: 

111 raise ValueError("Either 'command' or 'file' must be set") 

112 

113 time.sleep(3) 

        # sbatch prints "Submitted batch job <id>"; parse the ID from the last token.
        job_id = int(result.stdout.split()[-1])
        job.job_id = job_id
        job.status = JobStatus.PENDING

        logger.debug(f"Successfully submitted job '{job.name}' with ID {job_id}")

        all_callbacks = self.callbacks[:]
        if callbacks:
            all_callbacks.extend(callbacks)
        for callback in all_callbacks:
            callback.on_job_submitted(job)

        return job


    @staticmethod
    def retrieve(job_id: int) -> BaseJob:
        """Retrieve job information from SLURM.

        Args:
            job_id: SLURM job ID.

        Returns:
            Job object with current status.
        """
        return get_job_status(job_id)

    def cancel(self, job_id: int) -> None:
        """Cancel a SLURM job.

        Args:
            job_id: SLURM job ID to cancel.

        Raises:
            subprocess.CalledProcessError: If job cancellation fails.
        """
        logger.info(f"Cancelling job {job_id}")

        try:
            subprocess.run(
                ["scancel", str(job_id)],
                check=True,
            )
            logger.info(f"Successfully cancelled job {job_id}")
        except subprocess.CalledProcessError as e:
            logger.error(f"Failed to cancel job {job_id}: {e}")
            raise


    def queue(self, user: str | None = None) -> list[BaseJob]:
        """List jobs for a user.

        Args:
            user: Username (defaults to current user).

        Returns:
            List of BaseJob objects with their current status.
        """
        # squeue --format columns: job ID, partition, name, user, state,
        # elapsed time, time limit, node count, reason/nodelist.
        cmd = [
            "squeue",
            "--format",
            "%.18i %.9P %.15j %.8u %.8T %.10M %.9l %.6D %R",
            "--noheader",
        ]
        if user:
            cmd.extend(["--user", user])

        result = subprocess.run(cmd, capture_output=True, text=True, check=True)

        jobs = []
        for line in result.stdout.strip().split("\n"):
            if not line.strip():
                continue

            parts = line.split()
            if len(parts) >= 5:
                job_id = int(parts[0])
                job_name = parts[2]
                status_str = parts[4]

                try:
                    status = JobStatus(status_str)
                except ValueError:
                    status = JobStatus.PENDING  # Default for unknown status

                job = BaseJob(
                    name=job_name,
                    job_id=job_id,
                )
                job.status = status
                jobs.append(job)

        return jobs


    def monitor(
        self,
        job_obj_or_id: JobType | int,
        poll_interval: int = 5,
        callbacks: Sequence[Callback] | None = None,
    ) -> JobType:
        """Wait for a job to complete.

        Args:
            job_obj_or_id: Job object or job ID.
            poll_interval: Polling interval in seconds.
            callbacks: List of callbacks.

        Returns:
            Completed job object.

        Raises:
            RuntimeError: If the job fails, is cancelled, or times out.
        """
        if isinstance(job_obj_or_id, int):
            job = self.retrieve(job_obj_or_id)
        else:
            job = job_obj_or_id

        all_callbacks = self.callbacks[:]
        if callbacks:
            all_callbacks.extend(callbacks)

        msg = f"👀 {'MONITORING':<12} Job {job.name:<12} (ID: {job.job_id})"
        logger.info(msg)

        previous_status = None

        while True:
            job.refresh()

            # Log status changes
            if job.status != previous_status:
                status_str = job.status.value if job.status else "Unknown"
                logger.debug(f"Job(name={job.name}, id={job.job_id}) is {status_str}")
                previous_status = job.status

            match job.status:
                case JobStatus.COMPLETED:
                    logger.info(job_status_msg(job))
                    for callback in all_callbacks:
                        callback.on_job_completed(job)
                    return job
                case JobStatus.FAILED:
                    err_msg = job_status_msg(job) + "\n"
                    if isinstance(job, Job):
                        log_file = Path(job.log_dir) / f"{job.name}_{job.job_id}.out"
                        if log_file.exists():
                            with open(log_file) as f:
                                err_msg += f.read()
                            err_msg += f"\nLog file: {log_file}"
                        else:
                            err_msg += f"Log file not found: {log_file}"
                    for callback in all_callbacks:
                        callback.on_job_failed(job)
                    raise RuntimeError(err_msg)
                case JobStatus.CANCELLED | JobStatus.TIMEOUT:
                    err_msg = job_status_msg(job) + "\n"
                    if isinstance(job, Job):
                        log_file = Path(job.log_dir) / f"{job.name}_{job.job_id}.out"
                        if log_file.exists():
                            with open(log_file) as f:
                                err_msg += f.read()
                            err_msg += f"\nLog file: {log_file}"
                        else:
                            err_msg += f"Log file not found: {log_file}"
                    for callback in all_callbacks:
                        callback.on_job_cancelled(job)
                    raise RuntimeError(err_msg)
            time.sleep(poll_interval)


    def run(
        self,
        job: RunnableJobType,
        template_path: str | None = None,
        callbacks: Sequence[Callback] | None = None,
        poll_interval: int = 5,
        verbose: bool = False,
    ) -> RunnableJobType:
        """Submit a job and wait for completion."""
        submitted_job = self.submit(
            job, template_path=template_path, callbacks=callbacks, verbose=verbose
        )
        monitored_job = self.monitor(
            submitted_job, poll_interval=poll_interval, callbacks=callbacks
        )

        # Ensure the return type matches the expected type
        if isinstance(monitored_job, Job | ShellJob):
            return monitored_job
        else:
            # This should not happen in practice, but needed for type safety
            return submitted_job

    def _get_default_template(self) -> str:
        """Get the default job template path."""
        return str(files("srunx.templates").joinpath("base.slurm.jinja"))


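# Example usage (sketch): only methods defined above are used; the job ID 12345
# is a placeholder, not a real job.
#
#     client = Slurm()
#     pending = client.queue()            # jobs for the current user
#     job = client.retrieve(12345)        # look up one job by ID
#     finished = client.monitor(12345)    # block until it completes, or raise
#     client.cancel(12345)                # ask SLURM to scancel it

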

# Convenience functions for backward compatibility
def submit_job(
    job: RunnableJobType,
    template_path: str | None = None,
    callbacks: Sequence[Callback] | None = None,
    verbose: bool = False,
) -> RunnableJobType:
    """Submit a job to SLURM (convenience function).

    Args:
        job: Job configuration.
        template_path: Optional template path (uses default if not provided).
        callbacks: List of callbacks.
        verbose: Whether to print the rendered content.
    """
    client = Slurm()
    return client.submit(
        job, template_path=template_path, callbacks=callbacks, verbose=verbose
    )


def retrieve_job(job_id: int) -> BaseJob:
    """Get job status (convenience function).

    Args:
        job_id: SLURM job ID.
    """
    client = Slurm()
    return client.retrieve(job_id)


def cancel_job(job_id: int) -> None:
    """Cancel a job (convenience function).

    Args:
        job_id: SLURM job ID.
    """
    client = Slurm()
    client.cancel(job_id)
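

# Minimal callback sketch, assuming srunx.callbacks.Callback declares the hooks
# invoked above (on_job_submitted, on_job_completed, on_job_failed,
# on_job_cancelled); the class name and logging bodies are illustrative only.
class LoggingCallback(Callback):
    """Log each job lifecycle event handled by the client (sketch)."""

    def on_job_submitted(self, job):
        logger.info(f"Submitted {job.name} (ID: {job.job_id})")

    def on_job_completed(self, job):
        logger.info(f"Completed {job.name} (ID: {job.job_id})")

    def on_job_failed(self, job):
        logger.error(f"Failed {job.name} (ID: {job.job_id})")

    def on_job_cancelled(self, job):
        logger.warning(f"Cancelled {job.name} (ID: {job.job_id})")


# Usage: Slurm(callbacks=[LoggingCallback()])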