# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, you can obtain one at http://mozilla.org/MPL/2.0/.
import mail_builder
from botocore.exceptions import ClientError
from celery.utils.log import get_task_logger
from django.conf import settings
from django.db import transaction
from django.utils import timezone
from atmo.celery import celery
from atmo.clusters.models import Cluster
from atmo.clusters.provisioners import ClusterProvisioner
from .exceptions import SparkJobNotFound, SparkJobNotEnabled
from .models import SparkJob, SparkJobRun, SparkJobRunAlert
logger = get_task_logger(__name__)
@celery.task
def send_expired_mails():
"""
A Celery task the send emails for when a Spark job as expired
(the end_date has passed) to the owner.
"""
expired_spark_jobs = SparkJob.objects.filter(
expired_date__isnull=False,
)
for spark_job in expired_spark_jobs:
message = mail_builder.build_message(
'atmo/jobs/mails/expired.mail', {
'settings': settings,
'spark_job': spark_job,
}
)
message.send()
@celery.task
def expire_jobs():
"""
Periodic task to purge all schedule entries
that don't have a SparkJob instance anymore
"""
expired_spark_jobs = []
for spark_job in SparkJob.objects.lapsed():
with transaction.atomic():
expired_spark_jobs.append([spark_job.identifier, spark_job.pk])
removed = spark_job.expire()
logger.info(
'Spark job %s (%s) is expired.',
spark_job.pk,
spark_job.identifier,
)
if removed:
logger.info(
'Removing expired Spark job %s (%s) from schedule.',
spark_job.pk,
spark_job.identifier,
)
return expired_spark_jobs
@celery.task(max_retries=8, bind=True)
def update_jobs_statuses(self):
"""
A Celery task that updates the status of all active
job runs using the AWS EMR API or retry with an exponential backoff
when a certain number of failures have happened.
This task runs every 15 minutes (900 seconds, see ``CELERY_BEAT_SCHEDULE``
setting), which fits nicely in the backoff decay of 9 tries total
"""
spark_job_runs = SparkJobRun.objects.all()
# get the active (read: not terminated or failed) job runs
active_spark_job_runs = spark_job_runs.active().prefetch_related('spark_job')
logger.debug(
'Updating Spark job runs: %s',
list(active_spark_job_runs.values_list('pk', flat=True))
)
# create a map between the jobflow ids of the latest runs and the jobs
spark_job_run_map = {}
for spark_job_run in active_spark_job_runs:
spark_job_run_map[spark_job_run.jobflow_id] = spark_job_run
# get the created dates of the job runs to limit the ListCluster API call
provisioner = ClusterProvisioner()
runs_created_at = active_spark_job_runs.datetimes('created_at', 'day')
try:
# only fetch a cluster list if there are any runs at all
updated_spark_job_runs = []
if runs_created_at:
earliest_created_at = runs_created_at[0]
logger.debug('Fetching clusters since %s', earliest_created_at)
cluster_list = provisioner.list(created_after=earliest_created_at)
logger.debug('Clusters found: %s', cluster_list)
for cluster_info in cluster_list:
# filter out the clusters that don't relate to the job run ids
spark_job_run = spark_job_run_map.get(cluster_info['jobflow_id'])
if spark_job_run is None:
continue
logger.debug(
'Updating job status for %s, run %s',
spark_job_run.spark_job,
spark_job_run,
)
# update the Spark job run status
with transaction.atomic():
spark_job_run.sync(cluster_info)
updated_spark_job_runs.append(
[spark_job_run.spark_job.identifier, spark_job_run.pk]
)
return updated_spark_job_runs
except ClientError as exc:
self.retry(
exc=exc,
countdown=celery.backoff(self.request.retries),
)
[docs]class SparkJobRunTask(celery.Task):
"""
A Celery task base classes to be used by the
:func:`~atmo.jobs.tasks.run_job` task to simplify testing.
"""
throws = (
SparkJobNotFound,
SparkJobNotEnabled,
)
#: The max number of retries which does not run too long
#: when using the exponential backoff timeouts.
max_retries = 9
[docs] def get_spark_job(self, pk):
"""
Load the Spark job with the given primary key.
"""
spark_job = SparkJob.objects.filter(pk=pk).first()
if spark_job is None:
raise SparkJobNotFound('Cannot find Spark job with pk %s' % pk)
return spark_job
[docs] @transaction.atomic
def sync_run(self, spark_job):
"""
Updates the cluster status of the latest Spark job run,
if available.
"""
if spark_job.latest_run:
logger.debug('Updating Spark job: %s', spark_job)
spark_job.latest_run.sync()
return True
[docs] def check_enabled(self, spark_job):
"""
Checks if the job should be run at all
"""
if not spark_job.is_enabled:
# just ignore this
raise SparkJobNotEnabled(
'Spark job %s is not enabled, ignoring' %
spark_job
)
[docs] @transaction.atomic
def provision_run(self, spark_job, first_run=False):
"""
Actually run the given Spark job.
If this is the first run we'll update the "last_run_at" value
to the start date of the spark_job so Celery beat knows what's
going on.
"""
spark_job.run()
if first_run:
def update_last_run_at():
schedule_entry = spark_job.schedule.get()
if schedule_entry is None:
schedule_entry = spark_job.schedule.add()
schedule_entry.reschedule(last_run_at=spark_job.start_date)
transaction.on_commit(update_last_run_at)
[docs] @transaction.atomic
def unschedule_and_expire(self, spark_job):
"""
Remove the Spark job from the periodic schedule
and send an email to the owner that it was expired.
"""
logger.debug(
'The Spark job %s has expired was removed from the schedule',
spark_job,
)
spark_job.schedule.delete()
spark_job.expire()
[docs] def terminate_and_notify(self, spark_job):
"""
When the Spark job has timed out because it has run longer
than the maximum runtime we will terminate it (and its cluster)
and notify the owner to optimize the Spark job code.
"""
logger.debug(
'The last run of Spark job %s has not finished yet and timed out, '
'terminating it and notifying owner.', spark_job,
)
spark_job.terminate()
message = mail_builder.build_message(
'atmo/jobs/mails/timed_out.mail', {
'settings': settings,
'spark_job': spark_job,
}
)
message.send()
@celery.task(bind=True, base=SparkJobRunTask)
def run_job(self, pk, first_run=False):
"""
Run the Spark job with the given primary key.
See :class:`~atmo.jobs.tasks.SparkJobRunTask` for more details.
"""
try:
# get the Spark job (may fail with exception)
spark_job = self.get_spark_job(pk)
# update the cluster status of the latest Spark job run
updated = self.sync_run(spark_job)
if updated:
spark_job.refresh_from_db()
# check if the Spark job is enabled (may fail with exception)
self.check_enabled(spark_job)
if spark_job.is_runnable:
# if the latest run of the Spark job has finished
if spark_job.is_due:
# if current datetime is between Spark job's start and end date
self.provision_run(spark_job, first_run=first_run)
else:
# otherwise remove the job from the schedule and send
# an email to the Spark job owner
self.unschedule_and_expire(spark_job)
else:
if spark_job.has_timed_out:
# if the job has not finished and timed out
self.terminate_and_notify(spark_job)
else:
# if the job hasn't finished yet and also hasn't timed out yet.
# since the job timeout is limited to 24 hours this case can
# only happen for daily jobs that have a scheduling or processing
# delay, e.g. slow provisioning. we just retry again in a few
# minutes and see if we caught up with the delay
self.retry(countdown=60 * 10)
except ClientError as exc:
self.retry(
exc=exc,
countdown=celery.backoff(self.request.retries),
)
@celery.task
def send_run_alert_mails():
"""
A Celery task that sends an email to the owner when a Spark job run has
failed and records a datetime when it was sent.
"""
with transaction.atomic():
failed_run_alerts = SparkJobRunAlert.objects.select_for_update().filter(
reason_code__in=Cluster.FAILED_STATE_CHANGE_REASON_LIST,
mail_sent_date__isnull=True,
).prefetch_related('run__spark_job__created_by')
failed_jobs = []
for alert in failed_run_alerts:
with transaction.atomic():
failed_jobs.append(alert.run.spark_job.identifier)
message = mail_builder.build_message(
'atmo/jobs/mails/failed_run_alert.mail', {
'alert': alert,
'settings': settings,
}
)
message.send()
alert.mail_sent_date = timezone.now()
alert.save()
return failed_jobs