1"""Workflow runner for executing YAML-defined workflows with SLURM"""
3import time
4from collections import defaultdict
5from collections.abc import Sequence
6from concurrent.futures import ThreadPoolExecutor
7from pathlib import Path
8from typing import Any, Self
10import yaml
12from srunx.callbacks import Callback
13from srunx.client import Slurm
14from srunx.exceptions import WorkflowValidationError
15from srunx.logging import get_logger
16from srunx.models import (
17 Job,
18 JobEnvironment,
19 JobResource,
20 JobStatus,
21 RunnableJobType,
22 ShellJob,
23 Workflow,
24)
26logger = get_logger(__name__)
29class WorkflowRunner:
30 """Runner for executing workflows defined in YAML with dynamic job scheduling.
32 Jobs are executed as soon as their dependencies are satisfied,
33 rather than waiting for entire dependency levels to complete.
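
    Example (illustrative sketch; assumes a local ``workflow.yaml`` defines the jobs):

        runner = WorkflowRunner.from_yaml("workflow.yaml")
        results = runner.run()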
34 """
36 def __init__(
37 self, workflow: Workflow, callbacks: Sequence[Callback] | None = None
38 ) -> None:
39 """Initialize workflow runner.
41 Args:
42 workflow: Workflow to execute.
43 callbacks: List of callbacks for job notifications.
44 """
45 self.workflow = workflow
46 self.slurm = Slurm(callbacks=callbacks)
47 self.callbacks = callbacks or []
49 @classmethod
50 def from_yaml(
51 cls, yaml_path: str | Path, callbacks: Sequence[Callback] | None = None
52 ) -> Self:
53 """Load and validate a workflow from a YAML file.
55 Args:
56 yaml_path: Path to the YAML workflow definition file.
57 callbacks: List of callbacks for job notifications.
59 Returns:
60 WorkflowRunner instance with loaded workflow.
62 Raises:
63 FileNotFoundError: If the YAML file doesn't exist.
64 yaml.YAMLError: If the YAML is malformed.
65 ValidationError: If the workflow structure is invalid.
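
        Example:
            A minimal workflow file might look like this (illustrative sketch;
            field names follow ``parse_job``, and the exact shape of ``command``
            depends on the ``Job`` model):

                name: demo
                jobs:
                  - name: preprocess
                    command: python preprocess.py
                  - name: train
                    command: python train.py
                    depends_on: [preprocess]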
66 """
67 yaml_file = Path(yaml_path)
68 if not yaml_file.exists():
69 raise FileNotFoundError(f"Workflow file not found: {yaml_path}")
71 with open(yaml_file, encoding="utf-8") as f:
72 data = yaml.safe_load(f)
74 name = data.get("name", "unnamed")
75 jobs_data = data.get("jobs", [])
77 jobs = []
78 for job_data in jobs_data:
79 job = cls.parse_job(job_data)
80 jobs.append(job)
81 return cls(workflow=Workflow(name=name, jobs=jobs), callbacks=callbacks)
83 def get_independent_jobs(self) -> list[RunnableJobType]:
84 """Get all jobs that are independent of any other job."""
85 independent_jobs = []
86 for job in self.workflow.jobs:
87 if not job.depends_on:
88 independent_jobs.append(job)
89 return independent_jobs
91 def run(self) -> dict[str, RunnableJobType]:
92 """Run a workflow with dynamic job scheduling.
94 Jobs are executed as soon as their dependencies are satisfied.
96 Returns:
97 Dictionary mapping job names to completed Job instances.
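
        Example (illustrative; ``"train"`` is a hypothetical job name):

            results = runner.run()
            print(results["train"].status)  # JobStatus.COMPLETED on success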
98 """
99 logger.info(
100 f"🚀 Starting Workflow {self.workflow.name} with {len(self.workflow.jobs)} jobs"
101 )
102 for callback in self.callbacks:
103 callback.on_workflow_started(self.workflow)
105 # Track all jobs and results
106 all_jobs = self.workflow.jobs.copy()
107 results: dict[str, RunnableJobType] = {}
108 running_futures: dict[str, Any] = {}
110 # Build reverse dependency map for efficient lookups
111 dependents = defaultdict(set)
112 for job in all_jobs:
113 for dep in job.depends_on:
114 dependents[dep].add(job.name)
116 def execute_job(job: RunnableJobType) -> RunnableJobType:
117 """Execute a single job."""
118 logger.info(f"🌋 {'SUBMITTED':<12} Job {job.name:<12}")
            return self.slurm.run(job)

        def on_job_complete(job_name: str, result: RunnableJobType) -> list[str]:
            """Handle job completion and return newly ready job names."""
            results[job_name] = result
            completed_job_names = list(set(results.keys()))

            # Find newly ready jobs
            newly_ready = []
            for dependent_name in dependents[job_name]:
                dependent_job = next(j for j in all_jobs if j.name == dependent_name)
                if (
                    dependent_job.status == JobStatus.PENDING
                    and dependent_job.dependencies_satisfied(completed_job_names)
                ):
                    newly_ready.append(dependent_name)

            return newly_ready

        # Execute workflow with ThreadPoolExecutor
        with ThreadPoolExecutor(max_workers=8) as executor:
            # Submit initial ready jobs
            initial_jobs = self.get_independent_jobs()

            for job in initial_jobs:
                future = executor.submit(execute_job, job)
                running_futures[job.name] = future

            # Process completed jobs and schedule new ones
            while running_futures:
                # Check for completed futures
                completed = []
                for job_name, future in list(running_futures.items()):
                    if future.done():
                        completed.append((job_name, future))
                        del running_futures[job_name]

                if not completed:
                    time.sleep(0.1)  # Brief sleep to avoid busy waiting
                    continue

                # Handle completed jobs
                for job_name, future in completed:
                    try:
                        result = future.result()
                        newly_ready_names = on_job_complete(job_name, result)

                        # Schedule newly ready jobs
                        for ready_name in newly_ready_names:
                            if ready_name not in running_futures:
                                ready_job = next(
                                    j for j in all_jobs if j.name == ready_name
                                )
                                new_future = executor.submit(execute_job, ready_job)
                                running_futures[ready_name] = new_future

                    except Exception as e:
                        logger.error(f"❌ Job {job_name} failed: {e}")
                        raise

        # Verify all jobs completed successfully
        failed_jobs = [j.name for j in all_jobs if j.status == JobStatus.FAILED]
        incomplete_jobs = [
            j.name
            for j in all_jobs
            if j.status not in [JobStatus.COMPLETED, JobStatus.FAILED]
        ]

        if failed_jobs:
            logger.error(f"❌ Jobs failed: {failed_jobs}")
            raise RuntimeError(f"Workflow execution failed: {failed_jobs}")

        if incomplete_jobs:
            logger.error(f"❌ Jobs did not complete: {incomplete_jobs}")
            raise RuntimeError(f"Workflow execution incomplete: {incomplete_jobs}")

        logger.success(f"🎉 Workflow {self.workflow.name} completed!!")

        for callback in self.callbacks:
            callback.on_workflow_completed(self.workflow)

        return results

    def execute_from_yaml(self, yaml_path: str | Path) -> dict[str, RunnableJobType]:
        """Load and execute a workflow from a YAML file.

        Args:
            yaml_path: Path to YAML workflow file.

        Returns:
            Dictionary mapping job names to completed Job instances.
        """
        logger.info(f"Loading workflow from {yaml_path}")
        runner = self.from_yaml(yaml_path)
        return runner.run()

    @staticmethod
    def parse_job(data: dict[str, Any]) -> RunnableJobType:
        """Parse a single job definition from workflow YAML data.

        Args:
            data: Mapping with a required ``name`` and either ``path`` (shell job)
                or ``command``, plus optional ``depends_on``, ``resources``,
                ``environment``, ``log_dir``, and ``work_dir``.

        Returns:
            A ShellJob if ``path`` is given, otherwise a Job.

        Raises:
            WorkflowValidationError: If both ``path`` and ``command`` are given.
        """
        if data.get("path") and data.get("command"):
            raise WorkflowValidationError("Job cannot have both 'path' and 'command'")

        base = {"name": data["name"], "depends_on": data.get("depends_on", [])}

        if data.get("path"):
            return ShellJob.model_validate({**base, "path": data["path"]})

        resource = JobResource.model_validate(data.get("resources", {}))
        environment = JobEnvironment.model_validate(data.get("environment", {}))

        job_data = {
            **base,
            "command": data["command"],
            "resources": resource,
            "environment": environment,
        }
        if data.get("log_dir"):
            job_data["log_dir"] = data["log_dir"]
        if data.get("work_dir"):
            job_data["work_dir"] = data["work_dir"]

        return Job.model_validate(job_data)


def run_workflow_from_file(yaml_path: str | Path) -> dict[str, RunnableJobType]:
    """Convenience function to run a workflow from a YAML file.

    Args:
        yaml_path: Path to YAML workflow file.

    Returns:
        Dictionary mapping job names to completed Job instances.
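
    Example (illustrative):

        results = run_workflow_from_file("workflow.yaml")
        for name, job in results.items():
            print(name, job.status)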
255 """
256 runner = WorkflowRunner.from_yaml(yaml_path)
257 return runner.run()