Coverage for src/srunx/cli/main.py: 70%

222 statements  

1"""Main CLI interface for srunx.""" 

2 

3import argparse 

4import os 

5import sys 

6from pathlib import Path 

7 

8from rich.console import Console 

9from rich.table import Table 

10 

11from srunx.callbacks import SlackCallback 

12from srunx.client import Slurm 

13from srunx.logging import ( 

14 configure_cli_logging, 

15 configure_workflow_logging, 

16 get_logger, 

17) 

18from srunx.models import Job, JobEnvironment, JobResource 

19from srunx.runner import WorkflowRunner 

20 

21logger = get_logger(__name__) 

22 

23 

24def create_job_parser() -> argparse.ArgumentParser: 

25 """Create argument parser for job submission.""" 

26 parser = argparse.ArgumentParser( 

27 description="Submit SLURM jobs with various configurations", 

28 formatter_class=argparse.RawDescriptionHelpFormatter, 

29 ) 

30 

31 # Required arguments 

32 parser.add_argument( 

33 "command", 

34 nargs="+", 

35 help="Command to execute in the SLURM job", 

36 ) 

37 

38 # Job configuration 

39 parser.add_argument( 

40 "--name", 

41 "--job-name", 

42 type=str, 

43 default="job", 

44 help="Job name (default: %(default)s)", 

45 ) 

46 parser.add_argument( 

47 "--log-dir", 

48 type=str, 

49 default=os.getenv("SLURM_LOG_DIR", "logs"), 

50 help="Log directory (default: %(default)s)", 

51 ) 

52 parser.add_argument( 

53 "--work-dir", 

54 "--chdir", 

55 type=str, 

56 help="Working directory for the job", 

57 ) 

58 

59 # Resource configuration 

60 resource_group = parser.add_argument_group("Resource Options") 

61 resource_group.add_argument( 

62 "-N", 

63 "--nodes", 

64 type=int, 

65 default=1, 

66 help="Number of nodes (default: %(default)s)", 

67 ) 

68 resource_group.add_argument( 

69 "--gpus-per-node", 

70 type=int, 

71 default=0, 

72 help="Number of GPUs per node (default: %(default)s)", 

73 ) 

74 resource_group.add_argument( 

75 "--ntasks-per-node", 

76 type=int, 

77 default=1, 

78 help="Number of tasks per node (default: %(default)s)", 

79 ) 

80 resource_group.add_argument( 

81 "--cpus-per-task", 

82 type=int, 

83 default=1, 

84 help="Number of CPUs per task (default: %(default)s)", 

85 ) 

86 resource_group.add_argument( 

87 "--memory", 

88 "--mem", 

89 type=str, 

90 help="Memory per node (e.g., '32GB', '1TB')", 

91 ) 

92 resource_group.add_argument( 

93 "--time", 

94 "--time-limit", 

95 type=str, 

96 help="Time limit (e.g., '1:00:00', '30:00', '1-12:00:00')", 

97 ) 

98 

99 # Environment configuration 

100 env_group = parser.add_argument_group("Environment Options") 

101 env_group.add_argument( 

102 "--conda", 

103 type=str, 

104 help="Conda environment name", 

105 ) 

106 env_group.add_argument( 

107 "--venv", 

108 type=str, 

109 help="Virtual environment path", 

110 ) 

111 env_group.add_argument( 

112 "--sqsh", 

113 type=str, 

114 help="SquashFS image path", 

115 ) 

116 env_group.add_argument( 

117 "--env", 

118 action="append", 

119 dest="env_vars", 

120 help="Environment variable KEY=VALUE (can be used multiple times)", 

121 ) 

122 

123 # Execution options 

124 exec_group = parser.add_argument_group("Execution Options") 

125 exec_group.add_argument( 

126 "--template", 

127 type=str, 

128 help="Path to custom SLURM template file", 

129 ) 

130 exec_group.add_argument( 

131 "--wait", 

132 action="store_true", 

133 help="Wait for job completion", 

134 ) 

135 exec_group.add_argument( 

136 "--poll-interval", 

137 type=int, 

138 default=5, 

139 help="Polling interval in seconds when waiting (default: %(default)s)", 

140 ) 

141 

142 # Logging options 

143 log_group = parser.add_argument_group("Logging Options") 

144 log_group.add_argument( 

145 "--log-level", 

146 choices=["DEBUG", "INFO", "WARNING", "ERROR"], 

147 default="INFO", 

148 help="Set logging level (default: %(default)s)", 

149 ) 

150 log_group.add_argument( 

151 "--quiet", 

152 "-q", 

153 action="store_true", 

154 help="Only show warnings and errors", 

155 ) 

156 

157 # Callback options 

158 callback_group = parser.add_argument_group("Notification Options") 

159 callback_group.add_argument( 

160 "--slack", 

161 action="store_true", 

162 help="Send notifications to Slack", 

163 ) 

164 

165 # Misc options 

166 misc_group = parser.add_argument_group("Misc Options") 

167 misc_group.add_argument( 

168 "--verbose", 

169 action="store_true", 

170 help="Print the rendered content", 

171 ) 

172 

173 return parser 

174 

175 

176def create_status_parser() -> argparse.ArgumentParser: 

177 """Create argument parser for job status.""" 

178 parser = argparse.ArgumentParser( 

179 description="Check SLURM job status", 

180 formatter_class=argparse.RawDescriptionHelpFormatter, 

181 ) 

182 

183 parser.add_argument( 

184 "job_id", 

185 type=int, 

186 help="SLURM job ID to check", 

187 ) 

188 

189 return parser 

190 

191 

192def create_queue_parser() -> argparse.ArgumentParser: 

193 """Create argument parser for queueing jobs.""" 

194 parser = argparse.ArgumentParser( 

195 description="Queue SLURM jobs", 

196 formatter_class=argparse.RawDescriptionHelpFormatter, 

197 ) 

198 

199 parser.add_argument( 

200 "--user", 

201 "-u", 

202 type=str, 

203 help="Queue jobs for specific user (default: current user)", 

204 ) 

205 

206 return parser 

207 

208 

209def create_cancel_parser() -> argparse.ArgumentParser: 

210 """Create argument parser for job cancellation.""" 

211 parser = argparse.ArgumentParser( 

212 description="Cancel SLURM job", 

213 formatter_class=argparse.RawDescriptionHelpFormatter, 

214 ) 

215 

216 parser.add_argument( 

217 "job_id", 

218 type=int, 

219 help="SLURM job ID to cancel", 

220 ) 

221 

222 return parser 

223 

224 

225def create_main_parser() -> argparse.ArgumentParser: 

226 """Create main argument parser with subcommands.""" 

227 parser = argparse.ArgumentParser( 

228 description="srunx - Python library for SLURM job management", 

229 formatter_class=argparse.RawDescriptionHelpFormatter, 

230 ) 

231 

232 # Global options 

233 parser.add_argument( 

234 "--log-level", 

235 "-l", 

236 choices=["DEBUG", "INFO", "WARNING", "ERROR"], 

237 default="INFO", 

238 help="Set logging level (default: %(default)s)", 

239 ) 

240 parser.add_argument( 

241 "--quiet", 

242 "-q", 

243 action="store_true", 

244 help="Only show warnings and errors", 

245 ) 

246 

247 subparsers = parser.add_subparsers(dest="command", help="Available commands") 

248 

249 # Submit command (default) 

250 submit_parser = subparsers.add_parser("submit", help="Submit a SLURM job") 

251 submit_parser.set_defaults(func=cmd_submit) 

252 _copy_parser_args(create_job_parser(), submit_parser) 

253 

254 # Status command 

255 status_parser = subparsers.add_parser("status", help="Check job status") 

256 status_parser.set_defaults(func=cmd_status) 

257 _copy_parser_args(create_status_parser(), status_parser) 

258 

259 # Queue command 

260 queue_parser = subparsers.add_parser("queue", help="Queue jobs") 

261 queue_parser.set_defaults(func=cmd_queue) 

262 _copy_parser_args(create_queue_parser(), queue_parser) 

263 

264 # Cancel command 

265 cancel_parser = subparsers.add_parser("cancel", help="Cancel job") 

266 cancel_parser.set_defaults(func=cmd_cancel) 

267 _copy_parser_args(create_cancel_parser(), cancel_parser) 

268 

269 # Flow command 

270 flow_parser = subparsers.add_parser("flow", help="Workflow management") 

271 flow_parser.set_defaults(func=None) # Will be overridden by subcommands 

272 

273 # Flow subcommands 

274 flow_subparsers = flow_parser.add_subparsers( 

275 dest="flow_command", help="Flow commands" 

276 ) 

277 

278 # Flow run command 

279 flow_run_parser = flow_subparsers.add_parser("run", help="Execute workflow") 

280 flow_run_parser.set_defaults(func=cmd_flow_run) 

281 flow_run_parser.add_argument( 

282 "yaml_file", 

283 type=str, 

284 help="Path to YAML workflow definition file", 

285 ) 

286 flow_run_parser.add_argument( 

287 "--dry-run", 

288 action="store_true", 

289 help="Show what would be executed without running jobs", 

290 ) 

291 flow_run_parser.add_argument( 

292 "--slack", 

293 action="store_true", 

294 help="Send notifications to Slack", 

295 ) 

296 

297 # Flow validate command 

298 flow_validate_parser = flow_subparsers.add_parser( 

299 "validate", help="Validate workflow" 

300 ) 

301 flow_validate_parser.set_defaults(func=cmd_flow_validate) 

302 flow_validate_parser.add_argument( 

303 "yaml_file", 

304 type=str, 

305 help="Path to YAML workflow definition file", 

306 ) 

307 

308 return parser 

309 

310 

311def _copy_parser_args( 

312 source_parser: argparse.ArgumentParser, target_parser: argparse.ArgumentParser 

313) -> None: 

314 """Copy arguments from source parser to target parser.""" 

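    # NOTE: this walks argparse's private `_actions` list and re-registers each
    # action on the target parser via `_add_action`, so the standalone parsers
    # and the subcommand parsers share the same option definitions.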
    for action in source_parser._actions:
        if action.dest == "help":
            continue
        target_parser._add_action(action)


def _parse_env_vars(env_var_list: list[str] | None) -> dict[str, str]:
    """Parse environment variables from list of KEY=VALUE strings."""
    env_vars = {}
    if env_var_list:
        for env_var in env_var_list:
            if "=" in env_var:
                key, value = env_var.split("=", 1)
                env_vars[key] = value
            else:
                logger.warning(f"Invalid environment variable format: {env_var}")
    return env_vars


def cmd_submit(args: argparse.Namespace) -> None:
    """Handle job submission command."""
    try:
        # Parse environment variables
        env_vars = _parse_env_vars(getattr(args, "env_vars", None))

        # Create job configuration
        resources = JobResource(
            nodes=args.nodes,
            gpus_per_node=args.gpus_per_node,
            ntasks_per_node=args.ntasks_per_node,
            cpus_per_task=args.cpus_per_task,
            memory_per_node=getattr(args, "memory", None),
            time_limit=getattr(args, "time", None),
        )

        environment = JobEnvironment(
            conda=getattr(args, "conda", None),
            venv=getattr(args, "venv", None),
            sqsh=getattr(args, "sqsh", None),
            env_vars=env_vars,
        )

        job_data = {
            "name": args.name,
            "command": args.command,
            "resources": resources,
            "environment": environment,
            "log_dir": args.log_dir,
        }

        if args.work_dir is not None:
            job_data["work_dir"] = args.work_dir

        job = Job.model_validate(job_data)

        if args.slack:
            webhook_url = os.getenv("SLACK_WEBHOOK_URL")
            if not webhook_url:
                raise ValueError("SLACK_WEBHOOK_URL is not set")
            callbacks = [SlackCallback(webhook_url=webhook_url)]
        else:
            callbacks = []

        # Submit job
        client = Slurm(callbacks=callbacks)
        submitted_job = client.submit(
            job, getattr(args, "template", None), verbose=args.verbose
        )

        logger.info(f"Submitted job {submitted_job.job_id}: {submitted_job.name}")

        # Wait for completion if requested
        if getattr(args, "wait", False):
            logger.info(f"Waiting for job {submitted_job.job_id} to complete...")
            completed_job = client.monitor(
                submitted_job, poll_interval=args.poll_interval
            )
            status_str = (
                completed_job.status.value if completed_job.status else "Unknown"
            )
            logger.info(
                f"Job {submitted_job.job_id} completed with status: {status_str}"
            )

    except Exception as e:
        logger.error(f"Error submitting job: {e}")
        sys.exit(1)


def cmd_status(args: argparse.Namespace) -> None:
    """Handle job status command."""
    try:
        client = Slurm()
        job = client.retrieve(args.job_id)

        logger.info(f"Job ID: {job.job_id}")
        logger.info(f"Name: {job.name}")
        if job.status:
            logger.info(f"Status: {job.status.value}")
        else:
            logger.info("Status: Unknown")

    except Exception as e:
        logger.error(f"Error getting job status: {e}")
        sys.exit(1)


def cmd_queue(args: argparse.Namespace) -> None:
    """Handle job queueing command."""
    try:
        client = Slurm()
        jobs = client.queue(getattr(args, "user", None))

        if not jobs:
            logger.info("No jobs found")
            return

        logger.info(f"{'Job ID':<12} {'Name':<20} {'Status':<12}")
        logger.info("-" * 45)
        for job in jobs:
            status_str = job.status.value if job.status else "Unknown"
            logger.info(f"{job.job_id:<12} {job.name:<20} {status_str:<12}")

    except Exception as e:
        logger.error(f"Error queueing jobs: {e}")
        sys.exit(1)


def cmd_cancel(args: argparse.Namespace) -> None:
    """Handle job cancellation command."""
    try:
        client = Slurm()
        client.cancel(args.job_id)
        logger.info(f"Cancelled job {args.job_id}")

    except Exception as e:
        logger.error(f"Error cancelling job: {e}")
        sys.exit(1)


def cmd_flow_run(args: argparse.Namespace) -> None:
    """Handle flow run command."""
    # Configure logging for workflow execution
    configure_workflow_logging(level=getattr(args, "log_level", "INFO"))

    try:
        yaml_file = Path(args.yaml_file)
        if not yaml_file.exists():
            logger.error(f"Workflow file not found: {args.yaml_file}")
            sys.exit(1)

        # Setup callbacks if requested
        callbacks = []
        if getattr(args, "slack", False):
            webhook_url = os.getenv("SLACK_WEBHOOK_URL")
            if not webhook_url:
                raise ValueError("SLACK_WEBHOOK_URL environment variable is not set")
            callbacks.append(SlackCallback(webhook_url=webhook_url))

        runner = WorkflowRunner.from_yaml(yaml_file, callbacks=callbacks)

        # Validate dependencies
        runner.workflow.validate()

        if args.dry_run:
            runner.workflow.show()
            return

        # Execute workflow
        results = runner.run()

        logger.success(f"🎉 Workflow {runner.workflow.name} completed!!")
        table = Table(title=f"Workflow {runner.workflow.name} Summary")
        table.add_column("Job", justify="left", style="cyan", no_wrap=True)
        table.add_column("Status", justify="left", style="cyan", no_wrap=True)
        table.add_column("ID", justify="left", style="cyan", no_wrap=True)
        for job in results.values():
            table.add_row(job.name, job.status.value, str(job.job_id))

        console = Console()
        console.print(table)

    except Exception as e:
        logger.error(f"Workflow execution failed: {e}")
        sys.exit(1)


def cmd_flow_validate(args: argparse.Namespace) -> None:
    """Handle flow validate command."""
    # Configure logging for workflow validation
    configure_workflow_logging(level=getattr(args, "log_level", "INFO"))

    try:
        yaml_file = Path(args.yaml_file)
        if not yaml_file.exists():
            logger.error(f"Workflow file not found: {args.yaml_file}")
            sys.exit(1)

        runner = WorkflowRunner.from_yaml(yaml_file)

        # Validate dependencies
        runner.workflow.validate()

        logger.info("Workflow validation successful")

    except Exception as e:
        logger.error(f"Workflow validation failed: {e}")
        sys.exit(1)


def main() -> None:
    """Main entry point for the CLI."""
    parser = create_main_parser()
    args = parser.parse_args()

    # Configure logging
    log_level = getattr(args, "log_level", "INFO")
    quiet = getattr(args, "quiet", False)
    configure_cli_logging(level=log_level, quiet=quiet)

    # If no command specified, default to submit behavior for backward compatibility
    if not hasattr(args, "func") or args.func is None:
        # Check if this is a flow command without subcommand
        if hasattr(args, "command") and args.command == "flow":
            if not hasattr(args, "flow_command") or args.flow_command is None:
                logger.error("Flow command requires a subcommand (run or validate)")
                parser.print_help()
                sys.exit(1)
        else:
            # Try to parse as submit command
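            # NOTE: parse_args() with no argument re-reads sys.argv with the
            # standalone job parser, so invoking the CLI without an explicit
            # subcommand still falls back to plain job submission.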
            submit_parser = create_job_parser()
            try:
                submit_args = submit_parser.parse_args()
                cmd_submit(submit_args)
            except SystemExit:
                parser.print_help()
                sys.exit(1)
    else:
        args.func(args)


if __name__ == "__main__":
    main()
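
For reference, a minimal sketch of driving the submit path programmatically through the parsers defined above (the argument values are illustrative, and an installed srunx package with a reachable SLURM cluster is assumed):

from srunx.cli.main import cmd_submit, create_job_parser

# Parse an illustrative argument list instead of sys.argv, then submit the job.
args = create_job_parser().parse_args(
    ["--name", "demo", "--gpus-per-node", "2", "--conda", "ml", "python", "train.py"]
)
cmd_submit(args)  # builds the Job model, submits via Slurm(), and logs the job ID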