Configuring Apache Airflow for ML Pipelines
Apache Airflow is a mature orchestrator for pipelines defined as DAGs (directed acyclic graphs). It suits machine learning work when a pipeline has to mix data engineering and ML steps freely, or when Airflow already runs the team's other ETL tasks.
Airflow vs. Kubeflow for ML
Airflow is preferred when it is already used for data engineering, when ML and non-ML tasks need to run in a single DAG, or when the team already knows Airflow.
Kubeflow Pipelines is preferred when the team is ML-centric, when native ML primitives (metrics, artifacts) are needed, or when a Kubernetes-native workflow is required.
Installation with KubernetesExecutor
# Install via Helm (recommended)
helm repo add apache-airflow https://airflow.apache.org
helm upgrade --install airflow apache-airflow/airflow \
  --namespace airflow \
  --create-namespace \
  --set executor=KubernetesExecutor \
  --set config.logging.logging_level=INFO \
  --values airflow-values.yaml
ML Pipeline as an Airflow DAG
import logging
from datetime import datetime, timedelta

from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from kubernetes.client import models as k8s

# Placeholder: ID of the MLflow experiment that train.py logs to
EXPERIMENT_ID = "1"


def notify_on_slack(context):
    """Placeholder failure callback; replace the log call with a real Slack notification."""
    logging.error("Task failed: %s", context["task_instance"].task_id)


default_args = {
    "owner": "ml-team",
    "retries": 2,
    "retry_delay": timedelta(minutes=5),
    "on_failure_callback": notify_on_slack,
}

with DAG(
    "fraud_detection_training",
    default_args=default_args,
    schedule="0 2 * * 1",  # Mondays at 02:00
    start_date=datetime(2024, 1, 1),
    catchup=False,
    tags=["ml", "fraud-detection"],
) as dag:
    # Data preparation runs on a regular (CPU-only) pod
    prepare_data = KubernetesPodOperator(
        task_id="prepare_data",
        image="ml-pipeline:latest",
        cmds=["python", "prepare_data.py"],
        arguments=["--date={{ ds }}", "--output=s3://bucket/features/{{ ds }}/"],
        namespace="ml-pipelines",
        # Recent cncf.kubernetes provider releases take container_resources
        # instead of the older resources dict
        container_resources=k8s.V1ResourceRequirements(
            requests={"memory": "4Gi", "cpu": "2"},
        ),
        get_logs=True,
        on_finish_action="delete_pod",  # replaces the deprecated is_delete_operator_pod=True
    )

    # Training runs on a GPU pod
    train_model = KubernetesPodOperator(
        task_id="train_model",
        image="ml-pipeline-gpu:latest",
        cmds=["python", "train.py"],
        arguments=[
            "--data=s3://bucket/features/{{ ds }}/",
            "--run-name=fraud-{{ ds }}",
        ],
        namespace="ml-pipelines",
        container_resources=k8s.V1ResourceRequirements(
            requests={"memory": "32Gi", "cpu": "8"},
            limits={"nvidia.com/gpu": "1"},  # the GPU is requested through the resource limit
        ),
        # The annotation is informational only; it does not reserve a GPU
        annotations={"nvidia.com/gpu": "1"},
        tolerations=[{"key": "nvidia.com/gpu", "operator": "Exists", "effect": "NoSchedule"}],
        get_logs=True,
    )

    # Evaluation gate: a Python operator (cheap)
    def check_model_quality(**context):
        import mlflow

        client = mlflow.tracking.MlflowClient()
        runs = client.search_runs(
            experiment_ids=[EXPERIMENT_ID],
            filter_string=f"tags.run_date = '{context['ds']}'",
            order_by=["metrics.test_f1 DESC"],  # sort by the same metric that is checked below
            max_results=1,
        )
        if not runs:
            raise ValueError(f"No MLflow run found for {context['ds']}")
        run = runs[0]
        f1 = run.data.metrics.get("test_f1", 0)
        if f1 < 0.90:
            # Failing the task stops the pipeline before promotion
            raise ValueError(f"Model quality too low: F1={f1:.3f} < 0.90")
        context["ti"].xcom_push(key="run_id", value=run.info.run_id)

    quality_gate = PythonOperator(
        task_id="quality_gate",
        python_callable=check_model_quality,
    )

    # Promotion runs only if quality_gate succeeded
    promote_model = KubernetesPodOperator(
        task_id="promote_to_staging",
        image="ml-pipeline:latest",
        cmds=["python", "promote_model.py"],
        arguments=["--run-id={{ ti.xcom_pull(task_ids='quality_gate', key='run_id') }}"],
        namespace="ml-pipelines",
    )

    # Dependencies
    prepare_data >> train_model >> quality_gate >> promote_model
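The threshold check inside the quality gate above can be factored into a plain function, which makes it testable without an MLflow server or an Airflow worker. The following is a minimal sketch under stated assumptions: `pick_run_to_promote` and the `(run_id, metrics)` tuple shape are hypothetical names chosen for illustration, not part of Airflow or MLflow.

```python
from typing import Dict, List, Tuple

# Hypothetical shape for illustration: (run_id, metrics)
Run = Tuple[str, Dict[str, float]]


def pick_run_to_promote(
    runs: List[Run], metric: str = "test_f1", threshold: float = 0.90
) -> str:
    """Return the run_id of the best run, or raise if no run meets the threshold."""
    if not runs:
        raise ValueError("No runs found for this date")
    # Pick the run with the highest metric value; a missing metric counts as 0
    best_id, best_metrics = max(runs, key=lambda run: run[1].get(metric, 0.0))
    score = best_metrics.get(metric, 0.0)
    if score < threshold:
        raise ValueError(f"Model quality too low: {metric}={score:.3f} < {threshold:.2f}")
    return best_id


if __name__ == "__main__":
    candidates = [("run-a", {"test_f1": 0.88}), ("run-b", {"test_f1": 0.93})]
    print(pick_run_to_promote(candidates))
```

Inside the DAG, the Python callable would only need to convert the MLflow search results into this shape and push the returned run ID to XCom.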
TaskFlow API (Modern Approach)
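from datetime import datetime

from airflow.decorators import dag, task


def promote_to_staging(run_id: str) -> None:
    """Placeholder: replace with the real promotion logic."""
    print(f"Promoting run {run_id} to staging")


@dag(schedule="0 2 * * 1", start_date=datetime(2024, 1, 1), catchup=False)
def ml_pipeline():
    @task
    def prepare_data(ds=None) -> str:
        # Prepare the data; Airflow injects ds (the run's date as YYYY-MM-DD) from the task context
        return f"s3://bucket/features/{ds}/"

    @task
    def train_model(data_path: str) -> dict:
        # Run training (or trigger an external job); the values below are placeholders
        return {"run_id": "xxx", "f1": 0.924}

    @task
    def promote_if_good(metrics: dict) -> None:
        if metrics["f1"] >= 0.90:
            promote_to_staging(metrics["run_id"])

    # Return values are passed between tasks through XCom
    data = prepare_data()
    metrics = train_model(data)
    promote_if_good(metrics)


ml_pipeline()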
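Values returned by TaskFlow tasks travel between tasks as XComs, which Airflow serializes as JSON by default rather than pickling them. Metrics coming out of ML libraries are often not plain Python floats, so it can help to normalize the returned dict before handing it to Airflow. The sketch below shows one way to do that; `to_xcom_payload` is a hypothetical helper, not an Airflow API.

```python
import json
from typing import Any, Dict, Mapping


def to_xcom_payload(run_id: Any, metrics: Mapping[str, Any]) -> Dict[str, Any]:
    """Build a dict of plain Python types that is safe to return from a task."""
    payload: Dict[str, Any] = {"run_id": str(run_id)}
    for name, value in metrics.items():
        # float() accepts ints, numeric strings and most numeric scalar types
        payload[name] = float(value)
    # Fail early, inside the task, if anything is still not JSON-serializable
    json.dumps(payload)
    return payload


if __name__ == "__main__":
    print(to_xcom_payload(42, {"f1": "0.924", "epochs": 3}))
```

A task such as `train_model` could end with `return to_xcom_payload(run_id, metrics)`, which keeps the result independent of which extra types a given Airflow version is able to serialize.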
Airflow DAG Monitoring
The Airflow UI displays the status of each run, the duration of each task, and the task logs. Prometheus integration is available through airflow-exporter, with metrics such as airflow_dag_run_duration_seconds and airflow_task_fail_count. Alerts on failed tasks are sent to Slack or PagerDuty through on_failure_callback.
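A failure callback such as the `notify_on_slack` function referenced in `default_args` receives the task context as a dict. The sketch below is one possible implementation using only the standard library. It assumes a Slack incoming webhook whose URL is stored in an environment variable named `SLACK_WEBHOOK_URL` (an assumption made for this example), and `build_failure_message` is a hypothetical helper.

```python
import json
import os
import urllib.request
from types import SimpleNamespace


def build_failure_message(context: dict) -> str:
    """Format a short alert from the context Airflow passes to failure callbacks."""
    ti = context["task_instance"]
    return (
        f"Task failed: {ti.dag_id}.{ti.task_id} "
        f"(run date {context.get('ds')}, try {ti.try_number})\n"
        f"Error: {context.get('exception')}\n"
        f"Logs: {ti.log_url}"
    )


def notify_on_slack(context: dict) -> None:
    """Post the alert to a Slack incoming webhook; do nothing if no URL is configured."""
    webhook_url = os.environ.get("SLACK_WEBHOOK_URL")  # assumed variable name
    if not webhook_url:
        return
    body = json.dumps({"text": build_failure_message(context)}).encode("utf-8")
    request = urllib.request.Request(
        webhook_url, data=body, headers={"Content-Type": "application/json"}
    )
    urllib.request.urlopen(request, timeout=10)


if __name__ == "__main__":
    # Stand-in for a real TaskInstance, used only to preview the message format
    fake_ti = SimpleNamespace(
        dag_id="fraud_detection_training",
        task_id="train_model",
        try_number=3,
        log_url="http://localhost:8080/log",
    )
    print(build_failure_message({"task_instance": fake_ti, "ds": "2024-01-01"}))
```

Keeping the message formatting separate from the HTTP call makes the callback easy to test, and the same formatter could feed a PagerDuty integration.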