#!/usr/bin/env python3
"""
Kuhn & Rueß GmbH
Consulting and Development
https://kuhn-ruess.de

Special agent: pull AWS Lambda metrics from CloudWatch via boto3.

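This agent fetches the relevant CloudWatch metrics (invocations, errors,
throttles, average and maximum duration) for each function and emits one JSON
object per function, one per line, under the section header
<<<aws_lambda_cw:sep(0)>>>.

Required IAM permissions: cloudwatch:GetMetricData, plus cloudwatch:ListMetrics
when no --function is given (function discovery). With --role-arn, the access
key itself only needs sts:AssumeRole on that role; the CloudWatch permissions
then belong to the assumed role.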
"""
import argparse
import json
import sys
from datetime import datetime, timedelta, timezone


def _interval_seconds(text):
    """argparse ``type`` for ``--interval``: a positive multiple of 60 seconds.

    The interval is also used as the CloudWatch query period, which
    GetMetricData only accepts in multiples of 60 seconds for
    standard-resolution metrics such as the AWS/Lambda ones. Rejecting bad
    values here gives a clear usage error instead of a failed API call.
    """
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid integer value: {text!r}") from None
    if value <= 0 or value % 60:
        raise argparse.ArgumentTypeError(
            f"must be a positive multiple of 60 seconds, got {value}"
        )
    return value


def parse_arguments(argv):
    """Parse the agent's command line.

    Args:
        argv: Argument list without the program name.

    Returns:
        The parsed ``argparse.Namespace``. ``function`` is a (possibly empty)
        list of function names; ``interval`` is in seconds.

    Exits with status 2 (via argparse) on missing or invalid arguments.
    """
    parser = argparse.ArgumentParser(description="AWS Lambda CloudWatch monitoring")
    parser.add_argument("--access-key-id", required=True, help="AWS access key id")
    parser.add_argument("--secret-key", required=True, help="AWS secret access key")
    parser.add_argument("--region", default="eu-central-1", help="AWS region (default: eu-central-1)")
    parser.add_argument(
        "--role-arn",
        help="Optional IAM role ARN to assume via STS before reading CloudWatch. "
        "Use when the access key itself has no CloudWatch permissions and only "
        "serves to assume a (possibly cross-account) monitoring role.",
    )
    parser.add_argument(
        "--external-id",
        help="Optional ExternalId passed to sts:AssumeRole (only with --role-arn).",
    )
    parser.add_argument(
        "--role-session-name",
        default="checkmk-aws-lambda-cw",
        help="Session name for the assumed role (default: checkmk-aws-lambda-cw).",
    )
    parser.add_argument(
        "--function",
        action="append",
        default=[],
        metavar="NAME",
        help="Limit to this function name. May be given multiple times. "
        "If omitted, all functions reporting metrics are discovered.",
    )
    parser.add_argument(
        "--interval",
        # Also used as the CloudWatch Period, hence the multiple-of-60 check.
        type=_interval_seconds,
        default=600,
        help="Look-back window in seconds the metrics are aggregated over; "
        "must be a positive multiple of 60 (default: 600).",
    )
    return parser.parse_args(argv)


# Metric name -> (statistic, result key)
_METRICS = [
    ("Invocations", "Sum", "invocations"),
    ("Errors", "Sum", "errors"),
    ("Throttles", "Sum", "throttles"),
    ("Duration", "Average", "duration_avg"),
    ("Duration", "Maximum", "duration_max"),
]


def _discover_functions(client):
    names = set()
    paginator = client.get_paginator("list_metrics")
    for page in paginator.paginate(Namespace="AWS/Lambda", MetricName="Invocations"):
        for metric in page.get("Metrics", []):
            for dim in metric.get("Dimensions", []):
                if dim.get("Name") == "FunctionName":
                    names.add(dim["Value"])
    return sorted(names)


def _build_queries(functions):
    """Build one GetMetricData query per (function, metric/statistic) pair.

    Returns ``(queries, index)``. ``queries`` is a list of MetricDataQuery
    dicts in function-major order, with Ids of the form ``q<function
    position>_<metric position>``. ``index`` maps each Id back to its
    ``(function_name, result_key)``. Every query's ``Period`` is left at 0;
    the caller is expected to set it before sending the queries.
    """
    index = {}
    queries = []
    for fn_pos, function_name in enumerate(functions):
        for metric_pos, (metric_name, statistic, result_key) in enumerate(_METRICS):
            query_id = f"q{fn_pos}_{metric_pos}"
            metric = {
                "Namespace": "AWS/Lambda",
                "MetricName": metric_name,
                "Dimensions": [{"Name": "FunctionName", "Value": function_name}],
            }
            metric_stat = {
                "Metric": metric,
                "Period": 0,  # placeholder, filled in per call
                "Stat": statistic,
            }
            index[query_id] = (function_name, result_key)
            queries.append({"Id": query_id, "MetricStat": metric_stat, "ReturnData": True})
    return queries, index


def _chunked(seq, size):
    for i in range(0, len(seq), size):
        yield seq[i:i + size]


def _create_client(boto3, args):
    """Return a boto3 CloudWatch client for ``args.region``.

    The static access key from *args* is used directly unless
    ``args.role_arn`` is set. In that case the key only serves to call
    ``sts:AssumeRole``, and the client is built from the returned temporary
    credentials. Exceptions from boto3 or STS propagate to the caller.
    """
    client_kwargs = {
        "aws_access_key_id": args.access_key_id,
        "aws_secret_access_key": args.secret_key,
        "region_name": args.region,
    }
    if args.role_arn:
        # The access key only serves to assume the (possibly cross-account)
        # monitoring role; the temporary credentials carry the CloudWatch
        # permissions.
        sts = boto3.client("sts", **client_kwargs)
        assume_kwargs = {
            "RoleArn": args.role_arn,
            "RoleSessionName": args.role_session_name,
        }
        if args.external_id:
            assume_kwargs["ExternalId"] = args.external_id
        creds = sts.assume_role(**assume_kwargs)["Credentials"]
        client_kwargs = {
            "aws_access_key_id": creds["AccessKeyId"],
            "aws_secret_access_key": creds["SecretAccessKey"],
            "aws_session_token": creds["SessionToken"],
            "region_name": args.region,
        }
    return boto3.client("cloudwatch", **client_kwargs)


def _query_window(interval):
    """Return ``(start, end)`` of the look-back window as aware UTC datetimes.

    ``end`` is truncated to the current whole minute. GetMetricData rounds a
    recent ``StartTime`` down to the minute but uses ``EndTime`` (exclusive)
    as given. An unaligned ``end`` therefore makes the window slightly longer
    than one period. CloudWatch can then return a second, nearly empty
    datapoint, and because it is the newest it would be reported instead of
    the full window.
    """
    end = datetime.now(timezone.utc).replace(second=0, microsecond=0)
    return end - timedelta(seconds=interval), end


def _fetch_latest_values(client, queries, start, end, batch_size=500):
    """Run *queries* through GetMetricData and return ``{query_id: value}``.

    Only the newest datapoint of each query is kept. Queries without any
    datapoint in the window are absent from the result. API exceptions
    propagate to the caller.
    """
    latest = {}
    paginator = client.get_paginator("get_metric_data")
    # GetMetricData accepts at most 500 queries per call.
    for batch in _chunked(queries, batch_size):
        for page in paginator.paginate(
            MetricDataQueries=batch,
            StartTime=start,
            EndTime=end,
            ScanBy="TimestampDescending",
        ):
            for series in page.get("MetricDataResults", []):
                values = series.get("Values", [])
                # Newest first: a later page of the same query only carries
                # older datapoints, so never overwrite a value already seen.
                if values and series["Id"] not in latest:
                    latest[series["Id"]] = values[0]
    return latest


def main(argv):
    """Run the special agent and return the process exit code.

    After parsing *argv*, prints the section header, then one JSON object per
    function.

    Returns:
        1, with a message on stderr, when boto3 is missing or the client
        setup, ListMetrics or GetMetricData fails. 0 otherwise, including
        when there is no function to report.
    """
    args = parse_arguments(argv)
    print("<<<aws_lambda_cw:sep(0)>>>")

    try:
        import boto3
    except ImportError:
        sys.stderr.write("boto3 is not available in the site Python\n")
        return 1

    try:
        client = _create_client(boto3, args)
    except Exception as exc:  # noqa: BLE001 - report any setup failure to the check
        sys.stderr.write(f"could not create CloudWatch client: {exc}\n")
        return 1

    try:
        # dict.fromkeys drops repeated --function values while keeping order.
        functions = list(dict.fromkeys(args.function)) or _discover_functions(client)
    except Exception as exc:  # noqa: BLE001
        sys.stderr.write(f"ListMetrics failed: {exc}\n")
        return 1

    if not functions:
        # No error, just nothing to monitor yet.
        return 0

    start, end = _query_window(args.interval)
    queries, index = _build_queries(functions)
    for query in queries:
        query["MetricStat"]["Period"] = args.interval

    results = {f: {"name": f, "region": args.region} for f in functions}
    try:
        latest = _fetch_latest_values(client, queries, start, end)
        for query_id, value in latest.items():
            function, key = index[query_id]
            results[function][key] = value
    except Exception as exc:  # noqa: BLE001
        sys.stderr.write(f"GetMetricData failed: {exc}\n")
        return 1

    # Sum-based metrics default to 0 when CloudWatch returned no datapoint
    # (= the function was simply not invoked in the window).
    sum_keys = [key for _metric, stat, key in _METRICS if stat == "Sum"]
    for data in results.values():
        for key in sum_keys:
            data.setdefault(key, 0.0)
        print(json.dumps(data, sort_keys=True))

    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_status = main(sys.argv[1:])
    sys.exit(exit_status)
