Coverage for src/srunx/models.py: 78%

180 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-24 15:16 +0000

1"""Data models for SLURM job management.""" 

2 

3import os 

4import subprocess 

5import time 

6from enum import Enum 

7from pathlib import Path 

8from typing import Self 

9 

10import jinja2 

11from pydantic import BaseModel, Field, PrivateAttr, model_validator 

12 

13from srunx.exceptions import WorkflowValidationError 

14from srunx.logging import get_logger 

15 

# Module-level logger; BaseJob.refresh reports failed ``sacct`` queries through it.
logger = get_logger(__name__)

17 

18 

class JobStatus(Enum):
    """Lifecycle states shared by SLURM jobs and workflow jobs.

    Each member's value is its own upper-case name, so a plain state string
    such as ``"RUNNING"`` resolves with ``JobStatus("RUNNING")``.
    """

    # Used when the state could not be determined.
    UNKNOWN = "UNKNOWN"

    # States in which the job has not finished yet.
    PENDING = "PENDING"
    RUNNING = "RUNNING"

    # Final states: BaseJob.status stops refreshing once one of these is seen.
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"
    CANCELLED = "CANCELLED"
    TIMEOUT = "TIMEOUT"

29 

30 

class JobResource(BaseModel):
    """Resources requested from SLURM for a single job.

    Integer counts are bounds-checked on construction: ``nodes``,
    ``ntasks_per_node`` and ``cpus_per_task`` must be at least 1, while
    ``gpus_per_node`` may be 0. ``memory_per_node`` and ``time_limit`` are
    optional strings that this model stores without parsing.
    """

    # Node / task / CPU / GPU layout.
    nodes: int = Field(default=1, ge=1, description="Number of compute nodes")
    gpus_per_node: int = Field(default=0, ge=0, description="Number of GPUs per node")
    ntasks_per_node: int = Field(default=1, ge=1, description="Number of jobs per node")
    cpus_per_task: int = Field(default=1, ge=1, description="Number of CPUs per task")

    # Optional limits, kept as free-form strings.
    memory_per_node: str | None = Field(
        default=None,
        description="Memory per node (e.g., '32GB')",
    )
    time_limit: str | None = Field(
        default=None,
        description="Time limit (e.g., '1:00:00')",
    )

44 

45 

class JobEnvironment(BaseModel):
    """Job environment configuration.

    At most one activation mechanism (``conda``, ``venv`` or ``sqsh``) may be
    set. Leaving all three unset is valid: it is what the default
    ``JobEnvironment()`` produces (``Job.environment`` uses it as its default),
    and ``_build_environment_setup`` then emits only the ``env_vars`` exports.
    """

    conda: str | None = Field(default=None, description="Conda environment name")
    venv: str | None = Field(default=None, description="Virtual environment path")
    sqsh: str | None = Field(default=None, description="SquashFS image path")
    env_vars: dict[str, str] = Field(
        default_factory=dict, description="Environment variables"
    )

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Reject configurations that select more than one environment type.

        Raises:
            ValueError: If two or more of ``conda``, ``venv`` and ``sqsh`` are set.
        """
        selected = [self.conda, self.venv, self.sqsh]
        # Previously this required exactly one, which made the default
        # ``JobEnvironment()`` (and therefore ``Job`` without an explicit
        # environment) impossible to construct.
        if sum(value is not None for value in selected) > 1:
            raise ValueError("At most one of 'conda', 'venv', or 'sqsh' may be set")
        return self

63 

64 

class BaseJob(BaseModel):
    """Fields and SLURM status tracking shared by every job type."""

    name: str = Field(default="job", description="Job name")
    job_id: int | None = Field(default=None, description="SLURM job ID")
    depends_on: list[str] = Field(
        default_factory=list, description="Task dependencies for workflow execution"
    )

    # Cached status; read through the ``status`` property, which refreshes it.
    _status: JobStatus = PrivateAttr(default=JobStatus.PENDING)

    @property
    def status(self) -> JobStatus:
        """Current job status.

        Accessing ``job.status`` triggers ``refresh()`` (an ``sacct`` query)
        whenever the job has a ``job_id`` and the cached status is not one of
        COMPLETED, FAILED, CANCELLED or TIMEOUT.
        """
        if self.job_id is not None and self._status not in {
            JobStatus.COMPLETED,
            JobStatus.FAILED,
            JobStatus.CANCELLED,
            JobStatus.TIMEOUT,
        }:
            self.refresh()
        return self._status

    @status.setter
    def status(self, value: JobStatus) -> None:
        """Overwrite the cached status without querying SLURM."""
        self._status = value

    @staticmethod
    def _parse_sacct_state(raw_state: str) -> JobStatus:
        """Map an ``sacct`` State field onto a ``JobStatus``.

        Decorated states such as ``"CANCELLED by 1234"`` are reduced to their
        first word. A few SLURM states with no dedicated member are folded into
        the closest one. Anything else becomes ``UNKNOWN`` instead of raising
        ``ValueError``.
        """
        words = raw_state.split()
        if not words:
            return JobStatus.UNKNOWN
        state = words[0]
        try:
            return JobStatus(state)
        except ValueError:
            pass
        aliases = {
            "OUT_OF_MEMORY": JobStatus.FAILED,
            "NODE_FAIL": JobStatus.FAILED,
            "BOOT_FAIL": JobStatus.FAILED,
            "DEADLINE": JobStatus.TIMEOUT,
            "REQUEUED": JobStatus.PENDING,
            "COMPLETING": JobStatus.RUNNING,
        }
        # UNKNOWN is non-terminal, so ``status`` keeps polling such jobs.
        return aliases.get(state, JobStatus.UNKNOWN)

    def refresh(self, retries: int = 3) -> Self:
        """Query ``sacct`` and update the cached status in place.

        Args:
            retries: Number of queries to attempt while ``sacct`` returns no
                rows, sleeping one second between attempts. Values below 1 are
                treated as 1.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            subprocess.CalledProcessError: If ``sacct`` exits with a non-zero status.
        """
        if self.job_id is None:
            return self

        job_id = str(self.job_id)
        # Guarantee at least one query; ``retries <= 0`` used to leave the
        # parsed line unbound and crash with UnboundLocalError.
        attempts = max(1, retries)
        rows: list[str] = []
        for attempt in range(attempts):
            try:
                result = subprocess.run(
                    [
                        "sacct",
                        "-j",
                        job_id,
                        "--format",
                        "JobID,State",
                        "--noheader",
                        "--parsable2",
                    ],
                    capture_output=True,
                    text=True,
                    check=True,
                )
            except subprocess.CalledProcessError as e:
                logger.error(f"Failed to query job {self.job_id}: {e}")
                raise

            rows = [row for row in result.stdout.splitlines() if row.strip()]
            if rows:
                break
            if attempt < attempts - 1:
                time.sleep(1)
        else:
            # Every attempt came back empty.
            self._status = JobStatus.UNKNOWN
            return self

        # sacct may also list job steps (e.g. "<id>.batch"); prefer the row for
        # the job itself and fall back to the first row.
        row = next(
            (r for r in rows if r.partition("|")[0].strip() == job_id),
            rows[0],
        )
        _, _, state = row.partition("|")
        self._status = self._parse_sacct_state(state)
        return self

    def dependencies_satisfied(self, completed_job_names: list[str]) -> bool:
        """Return True if every dependency is completed and this job is still PENDING.

        Reading ``self.status`` may trigger ``refresh()``.
        """
        return self.status == JobStatus.PENDING and all(
            dep in completed_job_names for dep in self.depends_on
        )

136 

137 

class Job(BaseJob):
    """Represents a SLURM job with complete configuration."""

    command: list[str] = Field(description="Command to execute")
    resources: JobResource = Field(
        default_factory=JobResource, description="Resource requirements"
    )
    environment: JobEnvironment = Field(
        default_factory=JobEnvironment, description="Environment setup"
    )
    log_dir: str = Field(
        # Read SLURM_LOG_DIR when each Job is created rather than once at import
        # time, so the variable can be set after this module is imported.
        default_factory=lambda: os.getenv("SLURM_LOG_DIR", "logs"),
        description="Directory for log files",
    )
    # Defaults to the process's current directory at construction time.
    work_dir: str = Field(default_factory=os.getcwd, description="Working directory")

152 work_dir: str = Field(default_factory=os.getcwd, description="Working directory") 

153 

154 

class ShellJob(BaseJob):
    """A job whose work is given by a shell script path instead of a command list."""

    path: str = Field(
        description="Shell script path to execute",
    )

157 

158 

# Any job model from this module. Job and ShellJob both subclass BaseJob, so
# listing them is documentation rather than a type-level necessity.
type JobType = BaseJob | Job | ShellJob
# Jobs that carry something to run: a command list (Job) or a script path (ShellJob).
type RunnableJobType = Job | ShellJob

161 

162 

class Workflow:
    """Represents a workflow containing multiple jobs with dependencies.

    Dependencies are expressed by job *name* (``job.depends_on``), so names are
    expected to be unique within a workflow; ``validate()`` enforces this.
    """

    def __init__(self, name: str, jobs: list[RunnableJobType] | None = None) -> None:
        """Create a workflow.

        Args:
            name: Workflow name, shown by ``show()``.
            jobs: Initial jobs. The list is stored as-is (not copied), so
                ``add``/``remove`` mutate the caller's list.
        """
        if jobs is None:
            jobs = []

        self.name = name
        self.jobs = jobs

    def _find(self, name: str) -> RunnableJobType | None:
        """Return the first job called *name* without querying SLURM, or None."""
        return next((job for job in self.jobs if job.name == name), None)

    def add(self, job: RunnableJobType) -> None:
        """Append *job* after checking that all of its dependencies are present.

        Raises:
            WorkflowValidationError: If *job* depends on a name that no job
                already in the workflow has.
        """
        # Dependencies are job names, so compare against names. The previous
        # check compared strings to Job objects and rejected every dependency.
        known_names = {existing.name for existing in self.jobs}
        for dep in job.depends_on:
            if dep not in known_names:
                raise WorkflowValidationError(
                    f"Job '{job.name}' depends on unknown job '{dep}'"
                )
        self.jobs.append(job)

    def remove(self, job: RunnableJobType) -> None:
        """Remove *job*; ``list.remove`` raises ValueError if it is absent."""
        self.jobs.remove(job)

    def get(self, name: str) -> RunnableJobType | None:
        """Get a job by name, refreshing its SLURM status first.

        Returns None when no job has that name. ``refresh()`` queries ``sacct``
        if the job has a ``job_id``.
        """
        job = self._find(name)
        return job.refresh() if job is not None else None

    def get_dependencies(self, job_name: str) -> list[str]:
        """Get the dependency names of *job_name*, or [] if there is no such job.

        Dependencies are static configuration, so the job is looked up without
        refreshing its status (no ``sacct`` call).
        """
        job = self._find(job_name)
        return job.depends_on if job is not None else []

    def show(self) -> None:
        """Print a plain-text summary of the workflow and each of its jobs."""
        msg = f"""\
{" PLAN ":=^80}
Workflow: {self.name}
Jobs: {len(self.jobs)}
"""

        def add_indent(indent: int, msg: str) -> str:
            return " " * indent + msg

        for job in self.jobs:
            msg += add_indent(1, f"Job: {job.name}\n")
            if isinstance(job, Job):
                msg += add_indent(
                    2, f"{'Command:': <13} {' '.join(job.command or [])}\n"
                )
                msg += add_indent(
                    2,
                    f"{'Resources:': <13} {job.resources.nodes} nodes, {job.resources.gpus_per_node} GPUs/node\n",
                )
                if job.environment.conda:
                    msg += add_indent(
                        2, f"{'Conda env:': <13} {job.environment.conda}\n"
                    )
                if job.environment.sqsh:
                    msg += add_indent(2, f"{'Sqsh:': <13} {job.environment.sqsh}\n")
                if job.environment.venv:
                    msg += add_indent(2, f"{'Venv:': <13} {job.environment.venv}\n")
            elif isinstance(job, ShellJob):
                msg += add_indent(2, f"{'Path:': <13} {job.path}\n")
            if job.depends_on:
                msg += add_indent(
                    2, f"{'Dependencies:': <13} {', '.join(job.depends_on)}\n"
                )

        msg += f"{'=' * 80}\n"
        print(msg)

    def validate(self) -> None:
        """Validate workflow job dependencies.

        Raises:
            WorkflowValidationError: On duplicate job names, a dependency on an
                unknown job name, or a circular dependency.
        """
        job_names = {job.name for job in self.jobs}

        if len(job_names) != len(self.jobs):
            raise WorkflowValidationError("Duplicate job names found in workflow")

        for job in self.jobs:
            for dependency in job.depends_on:
                if dependency not in job_names:
                    raise WorkflowValidationError(
                        f"Job '{job.name}' depends on unknown job '{dependency}'"
                    )

        # Names are unique here, so a plain dict suffices. Using it instead of
        # ``self.get`` avoids an ``sacct`` query for every job visited below.
        deps_by_name = {job.name: job.depends_on for job in self.jobs}
        visited: set[str] = set()
        rec_stack: set[str] = set()

        def has_cycle(job_name: str) -> bool:
            # Depth-first search; meeting a name already on the current path
            # (rec_stack) means the dependencies loop back on themselves.
            if job_name in rec_stack:
                return True
            if job_name in visited:
                return False

            visited.add(job_name)
            rec_stack.add(job_name)

            for dependency in deps_by_name.get(job_name, []):
                if has_cycle(dependency):
                    return True

            rec_stack.remove(job_name)
            return False

        for job in self.jobs:
            if has_cycle(job.name):
                raise WorkflowValidationError(
                    f"Circular dependency detected involving job '{job.name}'"
                )
276 ) 

277 

278 

def render_job_script(
    template_path: Path | str,
    job: Job,
    output_dir: Path | str,
    verbose: bool = False,
) -> str:
    """Render *job* through a Jinja template and save it as ``<job.name>.slurm``.

    The template receives ``job_name``, ``command`` (the command list joined
    with single spaces), ``log_dir``, ``work_dir``, ``environment_setup`` and
    every field of ``job.resources``. Undefined template variables are errors
    (``StrictUndefined``).

    Args:
        template_path: Location of the Jinja template file.
        job: Job whose settings fill the template.
        output_dir: Directory that receives the rendered script.
        verbose: When true, also print the rendered script.

    Returns:
        The written script's path, as a string.

    Raises:
        FileNotFoundError: If *template_path* is not an existing file.
        jinja2.TemplateError: If rendering the template fails.
    """
    source = Path(template_path)
    if not source.is_file():
        raise FileNotFoundError(f"Template file '{template_path}' not found")

    template = jinja2.Template(
        source.read_text(encoding="utf-8"),
        undefined=jinja2.StrictUndefined,
    )

    context = {
        "job_name": job.name,
        "command": " ".join(job.command or []),
        "log_dir": job.log_dir,
        "work_dir": job.work_dir,
        "environment_setup": _build_environment_setup(job.environment),
    }
    # Resource fields are spread in last, exactly as before.
    context.update(job.resources.model_dump())

    script = template.render(context)
    if verbose:
        print(script)

    destination = Path(output_dir) / f"{job.name}.slurm"
    destination.write_text(script, encoding="utf-8")
    return str(destination)

330 

331 

def _build_environment_setup(environment: JobEnvironment) -> str:
    """Return the shell snippet that prepares a job's runtime environment.

    The snippet starts with one ``export`` line per ``env_vars`` entry (values
    are inserted verbatim, without quoting). It then adds activation lines for
    the first of conda, venv or sqsh that is set, checked in that order. Lines
    are newline-joined; an empty configuration yields an empty string.
    """
    lines = [f"export {name}={val}" for name, val in environment.env_vars.items()]

    if environment.conda:
        # Assumes a Miniconda install under the user's home directory.
        activate_script = f"{Path.home()}/miniconda3/bin/activate"
        lines += [
            f"source {activate_script}",
            "conda deactivate",
            f"conda activate {environment.conda}",
        ]
    elif environment.venv:
        lines.append(f"source {environment.venv}/bin/activate")
    elif environment.sqsh:
        # Default IMAGE to the sqsh path and expose container arguments to the template.
        lines += [
            f': "${{IMAGE:={environment.sqsh}}}"',
            "declare -a CONTAINER_ARGS=(",
            ' --container-image "$IMAGE"',
            ")",
        ]

    return "\n".join(lines)