Coverage for src/srunx/models.py: 78%
180 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-24 15:16 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-24 15:16 +0000
1"""Data models for SLURM job management."""
3import os
4import subprocess
5import time
6from enum import Enum
7from pathlib import Path
8from typing import Self
10import jinja2
11from pydantic import BaseModel, Field, PrivateAttr, model_validator
13from srunx.exceptions import WorkflowValidationError
14from srunx.logging import get_logger
16logger = get_logger(__name__)
class JobStatus(Enum):
    """Lifecycle states shared by SLURM jobs and workflow jobs.

    Each member's value is the same text as its name, so a member can be
    obtained directly from its string form, e.g. ``JobStatus("RUNNING")``.
    """

    # The state could not be determined.
    UNKNOWN = "UNKNOWN"

    # The job has not finished yet.
    PENDING = "PENDING"
    RUNNING = "RUNNING"

    # The job has finished, one way or another.
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"
    CANCELLED = "CANCELLED"
    TIMEOUT = "TIMEOUT"
class JobResource(BaseModel):
    """SLURM resource allocation requirements.

    Count fields are validated with lower bounds (``ge``); ``memory_per_node``
    and ``time_limit`` are optional free-form strings that default to ``None``.
    """

    nodes: int = Field(default=1, ge=1, description="Number of compute nodes")
    gpus_per_node: int = Field(default=0, ge=0, description="Number of GPUs per node")
    # The field counts tasks (not jobs) per node, so the description says so.
    ntasks_per_node: int = Field(
        default=1, ge=1, description="Number of tasks per node"
    )
    cpus_per_task: int = Field(default=1, ge=1, description="Number of CPUs per task")
    memory_per_node: str | None = Field(
        default=None, description="Memory per node (e.g., '32GB')"
    )
    time_limit: str | None = Field(
        default=None, description="Time limit (e.g., '1:00:00')"
    )
class JobEnvironment(BaseModel):
    """Job environment configuration.

    At most one of ``conda``, ``venv`` and ``sqsh`` may be set. Leaving all
    three unset is valid and means no environment is activated, which keeps
    a bare ``JobEnvironment()`` usable as a default value.
    """

    conda: str | None = Field(default=None, description="Conda environment name")
    venv: str | None = Field(default=None, description="Virtual environment path")
    sqsh: str | None = Field(default=None, description="SquashFS image path")
    env_vars: dict[str, str] = Field(
        default_factory=dict, description="Environment variables"
    )

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Reject configurations that select more than one environment kind.

        Returns:
            The validated model instance.

        Raises:
            ValueError: If more than one of ``conda``, ``venv`` and ``sqsh``
                is set.
        """
        selected = [
            kind
            for kind in ("conda", "venv", "sqsh")
            if getattr(self, kind) is not None
        ]
        # Zero selections is allowed: requiring exactly one would make the
        # argument-less ``JobEnvironment()`` constructor always fail.
        if len(selected) > 1:
            raise ValueError(
                "At most one of 'conda', 'venv', or 'sqsh' may be set "
                f"(got: {', '.join(selected)})"
            )
        return self
class BaseJob(BaseModel):
    """Fields and SLURM status tracking shared by every job kind.

    The status is kept in the private ``_status`` attribute and exposed
    through the ``status`` property, which refreshes it from ``sacct`` on
    access while the job has a ``job_id`` and has not reached a final state.
    """

    name: str = Field(default="job", description="Job name")
    job_id: int | None = Field(default=None, description="SLURM job ID")
    depends_on: list[str] = Field(
        default_factory=list, description="Task dependencies for workflow execution"
    )

    _status: JobStatus = PrivateAttr(default=JobStatus.PENDING)

    @property
    def status(self) -> JobStatus:
        """Return the job status, refreshing it first when it may have changed.

        A refresh happens only if the job has a ``job_id`` and the cached
        status is not one of COMPLETED, FAILED, CANCELLED or TIMEOUT.
        """
        if self.job_id is not None and self._status not in {
            JobStatus.COMPLETED,
            JobStatus.FAILED,
            JobStatus.CANCELLED,
            JobStatus.TIMEOUT,
        }:
            self.refresh()
        return self._status

    @status.setter
    def status(self, value: JobStatus) -> None:
        """Overwrite the cached status without querying SLURM."""
        self._status = value

    def refresh(self, retries: int = 3) -> Self:
        """Query ``sacct`` and update ``_status`` in place.

        Args:
            retries: Maximum number of ``sacct`` invocations. The query is
                repeated (one second apart) while ``sacct`` prints nothing.
                Values below 1 are treated as 1.

        Returns:
            This job, so calls can be chained.

        Raises:
            subprocess.CalledProcessError: If ``sacct`` exits with a non-zero
                status.
        """
        if self.job_id is None:
            return self

        # Always query at least once so ``line`` is defined even when the
        # caller passes ``retries <= 0``.
        attempts = max(retries, 1)
        line = ""
        for attempt in range(attempts):
            try:
                result = subprocess.run(
                    [
                        "sacct",
                        "-j",
                        str(self.job_id),
                        "--format",
                        "JobID,State",
                        "--noheader",
                        "--parsable2",
                    ],
                    capture_output=True,
                    text=True,
                    check=True,
                )
            except subprocess.CalledProcessError as e:
                logger.error(f"Failed to query job {self.job_id}: {e}")
                raise

            output = result.stdout.strip()
            # Only the first record is used; any further lines are ignored.
            line = output.split("\n")[0] if output else ""
            if line:
                break
            if attempt < attempts - 1:
                time.sleep(1)

        if not line:
            self._status = JobStatus.UNKNOWN
            return self

        # ``partition`` tolerates a record without the "|" delimiter.
        _, _, raw_state = line.partition("|")
        # sacct can decorate a state with extra words (e.g. "CANCELLED by
        # <uid>"), so only the leading word is matched against the enum.
        words = raw_state.split()
        state_name = words[0] if words else ""
        try:
            self._status = JobStatus(state_name)
        except ValueError:
            # States this enum does not model are reported as UNKNOWN
            # instead of crashing every caller that reads ``status``.
            logger.warning(
                f"Unrecognized state '{raw_state}' for job {self.job_id}; "
                "treating it as UNKNOWN"
            )
            self._status = JobStatus.UNKNOWN
        return self

    def dependencies_satisfied(self, completed_job_names: list[str]) -> bool:
        """Return True if this job is still pending and all its dependencies are done.

        Args:
            completed_job_names: Names of the jobs that have completed.
        """
        # Reading ``status`` may refresh it; a non-pending job never qualifies.
        if self.status != JobStatus.PENDING:
            return False
        completed = set(completed_job_names)
        return all(dep in completed for dep in self.depends_on)
class Job(BaseJob):
    """A command-based SLURM job together with its full run configuration."""

    # Required: the command as a list of arguments.
    command: list[str] = Field(description="Command to execute")

    # Both sub-models default to a freshly constructed instance.
    resources: JobResource = Field(
        description="Resource requirements",
        default_factory=JobResource,
    )
    environment: JobEnvironment = Field(
        description="Environment setup",
        default_factory=JobEnvironment,
    )

    # ``os.getenv`` runs once, when this class body is executed, so the
    # default reflects SLURM_LOG_DIR as it was at that moment.
    log_dir: str = Field(
        description="Directory for log files",
        default=os.getenv("SLURM_LOG_DIR", "logs"),
    )

    # ``os.getcwd`` is called each time a Job is created without ``work_dir``.
    work_dir: str = Field(
        description="Working directory",
        default_factory=os.getcwd,
    )
class ShellJob(BaseJob):
    """A job defined by the path of a shell script to execute."""

    # Required: no default is provided.
    path: str = Field(description="Shell script path to execute")
159type JobType = BaseJob | Job | ShellJob
160type RunnableJobType = Job | ShellJob
class Workflow:
    """Represents a workflow containing multiple jobs with dependencies."""

    def __init__(self, name: str, jobs: list[RunnableJobType] | None = None) -> None:
        """Create a workflow.

        Args:
            name: Workflow name.
            jobs: Initial jobs. The list is stored as given (not copied);
                ``None`` starts the workflow with an empty list.
        """
        self.name = name
        self.jobs = jobs if jobs is not None else []

    def add(self, job: RunnableJobType) -> None:
        """Append a job whose dependencies are already part of the workflow.

        Args:
            job: Job to add.

        Raises:
            WorkflowValidationError: If ``job`` depends on a job name that is
                not in the workflow yet.
        """
        # Dependencies are job *names*, so compare against names rather than
        # against the job objects themselves.
        known_names = {existing.name for existing in self.jobs}
        for dep in job.depends_on:
            if dep not in known_names:
                raise WorkflowValidationError(
                    f"Job '{job.name}' depends on unknown job '{dep}'"
                )
        self.jobs.append(job)

    def remove(self, job: RunnableJobType) -> None:
        """Remove a job; raises ``ValueError`` if it is not in the workflow."""
        self.jobs.remove(job)

    def get(self, name: str) -> RunnableJobType | None:
        """Get a job by name.

        The first job with a matching name is refreshed via ``refresh()`` and
        returned; ``None`` is returned when no job matches.
        """
        for job in self.jobs:
            if job.name == name:
                return job.refresh()
        return None

    def get_dependencies(self, job_name: str) -> list[str]:
        """Get dependencies for a specific job (empty list if the job is unknown)."""
        job = self.get(job_name)
        return job.depends_on if job else []

    def show(self) -> None:
        """Print a plain-text summary of the workflow and its jobs."""

        def add_indent(indent: int, msg: str) -> str:
            return " " * indent + msg

        lines = [
            f"{' PLAN ':=^80}\n",
            f"Workflow: {self.name}\n",
            f"Jobs: {len(self.jobs)}\n",
        ]

        for job in self.jobs:
            lines.append(add_indent(1, f"Job: {job.name}\n"))
            if isinstance(job, Job):
                lines.append(
                    add_indent(
                        2, f"{'Command:': <13} {' '.join(job.command or [])}\n"
                    )
                )
                lines.append(
                    add_indent(
                        2,
                        f"{'Resources:': <13} {job.resources.nodes} nodes, {job.resources.gpus_per_node} GPUs/node\n",
                    )
                )
                if job.environment.conda:
                    lines.append(
                        add_indent(
                            2, f"{'Conda env:': <13} {job.environment.conda}\n"
                        )
                    )
                if job.environment.sqsh:
                    lines.append(
                        add_indent(2, f"{'Sqsh:': <13} {job.environment.sqsh}\n")
                    )
                if job.environment.venv:
                    lines.append(
                        add_indent(2, f"{'Venv:': <13} {job.environment.venv}\n")
                    )
            elif isinstance(job, ShellJob):
                lines.append(add_indent(2, f"{'Path:': <13} {job.path}\n"))
            # Dependencies are listed for every kind of job.
            if job.depends_on:
                lines.append(
                    add_indent(
                        2, f"{'Dependencies:': <13} {', '.join(job.depends_on)}\n"
                    )
                )

        lines.append(f"{'=' * 80}\n")
        print("".join(lines))

    def validate(self) -> None:
        """Validate workflow job dependencies.

        Raises:
            WorkflowValidationError: If two jobs share a name, a job depends
                on an unknown job, or the dependencies contain a cycle.
        """
        # Resolve names locally instead of through ``get``: ``get`` refreshes
        # each job it returns and rescans the job list on every lookup.
        jobs_by_name = {job.name: job for job in self.jobs}

        if len(jobs_by_name) != len(self.jobs):
            raise WorkflowValidationError("Duplicate job names found in workflow")

        for job in self.jobs:
            for dependency in job.depends_on:
                if dependency not in jobs_by_name:
                    raise WorkflowValidationError(
                        f"Job '{job.name}' depends on unknown job '{dependency}'"
                    )

        # Depth-first search: a job reached again while it is still on the
        # current path (``rec_stack``) closes a cycle.
        visited: set[str] = set()
        rec_stack: set[str] = set()

        def has_cycle(job_name: str) -> bool:
            if job_name in rec_stack:
                return True
            if job_name in visited:
                return False

            visited.add(job_name)
            rec_stack.add(job_name)

            # Every dependency name was checked above, so the lookup is safe.
            for dependency in jobs_by_name[job_name].depends_on:
                if has_cycle(dependency):
                    return True

            rec_stack.remove(job_name)
            return False

        for job in self.jobs:
            if has_cycle(job.name):
                raise WorkflowValidationError(
                    f"Circular dependency detected involving job '{job.name}'"
                )
def render_job_script(
    template_path: Path | str,
    job: Job,
    output_dir: Path | str,
    verbose: bool = False,
) -> str:
    """Render a SLURM job script from a template.

    Args:
        template_path: Path to the Jinja template file.
        job: Job configuration.
        output_dir: Directory where the generated script will be saved.
            It is created (including parents) if it does not exist.
        verbose: Whether to print the rendered content.

    Returns:
        Path to the generated SLURM batch script
        (``<output_dir>/<job.name>.slurm``).

    Raises:
        FileNotFoundError: If the template file does not exist.
        jinja2.TemplateError: If template rendering fails.
    """
    template_file = Path(template_path)
    if not template_file.is_file():
        raise FileNotFoundError(f"Template file '{template_path}' not found")

    # StrictUndefined makes a template variable that is not supplied below
    # fail loudly instead of rendering as an empty string.
    template = jinja2.Template(
        template_file.read_text(encoding="utf-8"),
        undefined=jinja2.StrictUndefined,
    )

    # Prepare template variables; the resource fields are merged in under
    # their own field names.
    template_vars = {
        "job_name": job.name,
        "command": " ".join(job.command or []),
        "log_dir": job.log_dir,
        "work_dir": job.work_dir,
        "environment_setup": _build_environment_setup(job.environment),
        **job.resources.model_dump(),
    }

    rendered_content = template.render(template_vars)

    if verbose:
        print(rendered_content)

    # Generate output file, creating the target directory first so a
    # not-yet-existing ``output_dir`` does not abort the write.
    output_directory = Path(output_dir)
    output_directory.mkdir(parents=True, exist_ok=True)
    output_path = output_directory / f"{job.name}.slurm"
    output_path.write_text(rendered_content, encoding="utf-8")

    return str(output_path)
def _build_environment_setup(environment: JobEnvironment) -> str:
    """Assemble the shell lines that prepare a job's environment.

    One ``export`` line per entry of ``environment.env_vars`` comes first,
    followed by the activation lines for the first of conda, venv or sqsh
    that is set (none if all are unset).

    Args:
        environment: Environment configuration to translate.

    Returns:
        The lines joined with newlines; an empty string if there are none.
    """
    # Environment variable exports, in the mapping's own order.
    script = [f"export {name}={val}" for name, val in environment.env_vars.items()]

    # Activation snippet for the selected environment kind.
    if environment.conda:
        script += [
            f"source {Path.home()}/miniconda3/bin/activate",
            "conda deactivate",
            f"conda activate {environment.conda}",
        ]
    elif environment.venv:
        script += [f"source {environment.venv}/bin/activate"]
    elif environment.sqsh:
        script += [
            f': "${{IMAGE:={environment.sqsh}}}"',
            "declare -a CONTAINER_ARGS=(",
            ' --container-image "$IMAGE"',
            ")",
        ]

    return "\n".join(script)