#!/usr/bin/env python

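# Get a list of all jobs in the Slurm queue (via `squeue --json`) and catalog
# them by their stage id, which should appear in the 'comment' field as
# `stage_id=<id>;` (or in the name field for the initial job -- note that only
# the 'comment' field is currently parsed, so jobs without a stage id in their
# comment are not cataloged).
# If every RUNNING/PENDING job of a stage is held by the user
# (state_reason 'JobHeldUser'), cancel those jobs, since they would otherwise
# stay held indefinitely.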

import json
import subprocess
import re
import sys


def get_stage_id(job_desc):
    """Return the stage id encoded in a job's 'comment' field.

    The comment is searched for the first ``stage_id=<value>;`` token and
    ``<value>`` (everything up to the next ``;``) is returned.

    :param job_desc: a single job record from the ``squeue --json`` output
    :returns: the stage id string, or None if the record has no 'comment'
        key or the comment contains no ``stage_id=...;`` token
    """
    if 'comment' in job_desc:
        # non-greedy so the value stops at the first ';' after 'stage_id='
        match = re.search(r'stage_id=(.*?);', job_desc['comment'])
        if match is not None:
            return match.group(1)
    return None
    
    
# Only jobs in these states are considered when deciding whether a stage is
# still alive; jobs in any other state (full list of job states:
# https://slurm.schedmd.com/squeue.html) are ignored entirely.
CATALOG_STATES = frozenset(['RUNNING', 'PENDING'])


def _run(cmd):
    """Echo and run ``cmd`` (an argv list, no shell); return the CompletedProcess.

    stdout is captured; stderr is left attached to the terminal so the
    tool's own diagnostics stay visible. Exits the script with status 1 if
    the executable cannot be started (e.g. not on PATH).
    """
    print("executing: " + ' '.join(cmd))
    try:
        return subprocess.run(cmd, stdout=subprocess.PIPE)
    except OSError as err:
        print(f"Error: could not execute {cmd[0]}: {err}", file=sys.stderr)
        sys.exit(1)


def _job_states(job):
    """Return the set of state names reported for ``job``.

    ``job_state`` may be a single string or a list of state flags
    (e.g. ``["PENDING"]``) depending on the Slurm version producing the JSON;
    both forms are normalised to a set. A missing or null state yields an
    empty set.
    """
    state = job.get('job_state')
    if state is None:
        return set()
    if isinstance(state, str):
        return {state}
    return set(state)


def _find_held_stage_jobs(jobs):
    """Return the ids of held jobs whose stage has no active job.

    A stage is active as soon as one of its jobs is RUNNING, or PENDING with a
    state_reason other than 'JobHeldUser'. The held ('JobHeldUser') jobs of
    every non-active stage are returned, grouped by stage in first-seen order.
    Jobs without a stage id, or in no state from CATALOG_STATES, are ignored.
    """
    active_stages = set()
    held_jobs_by_stage = {}

    for job in jobs:
        stage_id = get_stage_id(job)
        # no stage id: the job can't be attributed to a stage.
        # already-active stage: nothing of it will be canceled.
        if not stage_id or stage_id in active_stages:
            continue

        states = _job_states(job) & CATALOG_STATES
        if not states:
            continue

        if 'RUNNING' in states or job.get('state_reason') != 'JobHeldUser':
            # a live job keeps the whole stage alive, so drop any held jobs
            # of this stage recorded so far
            active_stages.add(stage_id)
            held_jobs_by_stage.pop(stage_id, None)
        else:
            held_jobs_by_stage.setdefault(stage_id, []).append(job['job_id'])

    return [job_id for ids in held_jobs_by_stage.values() for job_id in ids]


def main():
    """Cancel the user-held jobs of stages that have no active job left.

    Exits with status 1 if squeue fails or its output has no 'jobs' key, and
    with scancel's exit status if cancellation fails.
    """
    squeue = _run(['squeue', '--json'])
    if squeue.returncode != 0:
        # without this check a failed squeue surfaces as a JSONDecodeError
        print(f"Error: squeue exited with status {squeue.returncode}",
              file=sys.stderr)
        sys.exit(1)

    queued_jobs = json.loads(squeue.stdout.decode())
    if 'jobs' not in queued_jobs:
        print("Error: 'jobs' not found in squeue response", file=sys.stderr)
        sys.exit(1)

    jobs = queued_jobs['jobs']
    print(f"retrieved {len(jobs)} job descriptions")

    job_ids_to_cancel = _find_held_stage_jobs(jobs)
    if not job_ids_to_cancel:
        print("No jobs meeting criteria to be canceled")
        return

    print(f"attempting to cancel {len(job_ids_to_cancel)} jobs")
    # scancel takes the job ids as separate arguments
    scancel = _run(['scancel'] + [str(job_id) for job_id in job_ids_to_cancel])
    print(f"result: {scancel.stdout.decode()}")
    if scancel.returncode != 0:
        sys.exit(scancel.returncode)


if __name__ == "__main__":
    main()
