#!/usr/bin/env python3
# Copyright (C) 2023 Checkmk GmbH - License: GNU General Public License v2
# This file is part of Checkmk (https://checkmk.com). It is subject to the terms and
# conditions defined in the file COPYING, which is part of this source code package.

# SVA-Edit #
# Changes based on 2.5.0p2 from file packages/cmk-plugins/cmk/plugins/pure_storage_fa/special_agent/agent_pure_storage_fa.py
# rewrite with SVA Special Agent
# Author: sebastian.haeger@sva.de
# SVA-Edit End #

"""agent_pure_storage_fa

Checkmk special agent for monitoring Pure Storage FlashArray via REST API.
"""

# mypy: disable-error-code="no-any-return"
# mypy: disable-error-code="possibly-undefined"
# mypy: disable-error-code="type-arg"

from __future__ import annotations

import argparse
import json
import logging
import sys
from collections.abc import Iterator, Mapping, Sequence
from dataclasses import dataclass
from typing import NamedTuple

import requests
import urllib3

from cmk.password_store.v1_unstable import parser_add_secret_option, resolve_secret_option, Secret
from cmk.server_side_programs.v1_unstable import (
    HostnameValidationAdapter,
    report_agent_crashes,
    vcrtrace,
)

# Version of the upstream Checkmk agent this file is based on (see file header);
# also passed to report_agent_crashes() and embedded in USER_AGENT.
__version__ = "2.5.0p2"

# Agent name; used for the logger name and passed to report_agent_crashes().
AGENT = "pure_storage_fa"

_LOGGER = logging.getLogger(f"agent_{AGENT}")
# SVA-Edit #
# Value of the "User-Agent" header sent with the login and section requests.
USER_AGENT = f"checkmk-special-purefa-ext-{__version__}"
# SVA-Edit End #
# Name of the secret command line option carrying the API token ("--apitoken").
API_TOKEN_OPTION = "apitoken"


class _RestVersion(NamedTuple):
    """A ``major.minor`` REST API version of the FlashArray.

    As a NamedTuple, instances compare field-wise: first by ``major``, then
    by ``minor`` (so 2.10 sorts after 2.9).
    """

    major: int
    minor: int

    @classmethod
    def from_raw(cls, raw_version: str) -> _RestVersion:
        """Parse a version string such as ``"2.17"``.

        The text is split at the first dot only; both parts must be integers,
        otherwise ``ValueError`` is raised.
        """
        major, minor = map(int, raw_version.split(".", 1))
        return cls(major, minor)

    def __str__(self) -> str:
        """Render the version back into its ``major.minor`` text form."""
        return ".".join(str(part) for part in (self.major, self.minor))


# REST API version used in the login request path. Its major number also selects
# which versions are considered when picking the newest version offered by the array.
_REST_VERSION = _RestVersion(2, 0)


@dataclass(frozen=True, kw_only=True)
class _SectionSpec:
    """Describes one agent section and the REST endpoint it is collected from."""

    # Section name; emitted as the "<<<pure_storage_fa_<name>:sep(0)>>>" header.
    name: str
    # Endpoint path, requested as "/api/<latest version>/<path>".
    path: str
    # Oldest REST API version required; the section is skipped if the array's
    # latest version is older.
    min_version: _RestVersion
    # Optional query parameters sent with the GET request.
    params: Mapping[str, str] | None = None


# Sections collected by the agent, in output order. Each entry is fetched with a
# GET request to "/api/<latest version>/<path>" and written to stdout as
# "<<<pure_storage_fa_<name>:sep(0)>>>" followed by the JSON response body.
# Entries whose min_version is newer than the array's latest REST API version
# are skipped (and logged).
_SECTIONS = [
    _SectionSpec(
        name="arrays",
        path="arrays",
        min_version=_RestVersion(2, 2),
    ),
    _SectionSpec(
        name="volumes",
        path="volumes",
        min_version=_RestVersion(2, 0),
    ),
    _SectionSpec(
        name="hardware",
        path="hardware",
        min_version=_RestVersion(2, 2),
    ),
    _SectionSpec(
        name="alerts",
        path="alerts",
        min_version=_RestVersion(2, 2),
        # Query filter restricting the result to alerts in state 'open'.
        params={"filter": "state='open'"},
    ),
    # SVA-Edit #
    _SectionSpec(
        name="network_interfaces",
        path="network-interfaces",
        min_version=_RestVersion(2, 4),
    ),
    # NOTE: same endpoint as the "hardware" section above; it is requested a
    # second time and emitted under its own section name.
    _SectionSpec(
        name="hardware_ext",
        path="hardware",
        min_version=_RestVersion(2, 2),
    ),
    # NOTE: same endpoint as the "arrays" section above; it is requested a
    # second time and emitted under its own section name.
    _SectionSpec(
        name="arrays_ext",
        path="arrays",
        min_version=_RestVersion(2, 2),
    ),
    _SectionSpec(
        name="drives",
        path="drives",
        min_version=_RestVersion(2, 4),
    ),
    _SectionSpec(
        name="host_groups",
        path="host-groups/hosts",
        min_version=_RestVersion(2, 0),
    ),
    _SectionSpec(
        name="admins",
        path="admins",
        min_version=_RestVersion(2, 2),
    ),
    _SectionSpec(
        name="connections",
        path="array-connections",
        min_version=_RestVersion(2, 4),
    ),
    _SectionSpec(
        name="certificates",
        path="certificates",
        min_version=_RestVersion(2, 4),
    ),
    _SectionSpec(
        name="admins_settings",
        path="admins/settings",
        min_version=_RestVersion(2, 2),
    ),
    _SectionSpec(
        name="dns",
        path="dns",
        min_version=_RestVersion(2, 2),
    )
    # SVA-Edit End #
]


def parse_arguments(argv: Sequence[str] | None) -> argparse.Namespace:
    """Parse the command line of the special agent.

    The program name and description are taken from the module docstring.

    Args:
        argv: Arguments without the program name; ``None`` lets argparse read
            ``sys.argv``.

    Returns:
        The parsed arguments.
    """
    prog, description = __doc__.split("\n\n", maxsplit=1)
    parser = argparse.ArgumentParser(
        prog=prog, description=description, formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument(
        "--debug",
        "-d",
        action="store_true",
        help="Enable debug mode (keep some exceptions unhandled)",
    )
    parser.add_argument("--verbose", "-v", action="count", default=0)
    parser.add_argument(
        "--vcrtrace",
        "--tracefile",
        default=False,
        action=vcrtrace(
            # Mask every request header that carries a credential: the login request
            # sends the API token in "api-token", all later requests send the session
            # token in "x-auth-token". "authorization" is kept as before.
            # NOTE(review): the login *response* also carries the session token in its
            # "x-auth-token" header; presumably filter_headers only masks request
            # headers - confirm whether recorded responses need scrubbing as well.
            filter_headers=[
                ("authorization", "****"),
                ("api-token", "****"),
                ("x-auth-token", "****"),
            ],
        ),
    )
    parser.add_argument(
        "--timeout",
        type=int,
        default=5,
        help="Timeout in seconds for each HTTP request (default: 5)",
    )
    parser.add_argument(
        "--no-cert-check",
        action="store_true",
        help="""Disables the checking of the servers ssl certificate""",
    )
    parser.add_argument(
        "--cert-server-name",
        help=(
            "Provides this name for SNI and expects this as the server's name"
            " in the ssl certificate. Overrides '--no-cert-check'."
        ),
    )
    parser_add_secret_option(
        parser,
        long=f"--{API_TOKEN_OPTION}",
        required=True,
        help=(
            "Generate the API token through the Purity user interface"
            " (System > Users > Create API Token)"
            " or through the Purity command line interface"
            " (pureadmin create --api-token)"
        ),
    )
    parser.add_argument("server", type=str, help="Host name or IP address")
    return parser.parse_args(argv)


class AuthError(Exception):
    """Raised when logging in to the FlashArray REST API fails."""


class APIVersionError(Exception):
    """Raised when the REST API versions offered by the array cannot be determined."""


class SectionError(Exception):
    """Raised when the data of a single agent section cannot be collected."""


class _PureStorageFlashArraySession:
    """HTTP session for requests to ``https://<server>/api/<path>``.

    Applies the configured TLS verification setting and timeout to every request.
    """

    def __init__(self, server: str, cert_check: bool | str, timeout: int) -> None:
        """Create the session.

        Args:
            server: Host name or IP address of the array.
            cert_check: ``False`` disables certificate verification; a string is the
                name expected in the server's certificate (mounted via
                HostnameValidationAdapter); ``True`` keeps default verification.
            timeout: Timeout passed to requests for every request.
        """
        self._session = requests.Session()
        self._base_url = f"https://{server}"

        self._verify = True
        if cert_check is False:
            # Watch out: we must provide the verify keyword to every individual request call!
            # Else it will be overwritten by the REQUESTS_CA_BUNDLE env variable
            self._verify = False
            urllib3.disable_warnings(category=urllib3.exceptions.InsecureRequestWarning)
        elif isinstance(cert_check, str):
            self._session.mount(self._base_url, HostnameValidationAdapter(cert_check))

        self._timeout = timeout
        self._x_auth_token = ""

    def post(self, path: str, headers: Mapping[str, str]) -> requests.Response:
        """Send a POST request to ``/api/<path>`` and return the response."""
        return self._send("POST", path, headers)

    def get(
        self, path: str, headers: Mapping[str, str], params: Mapping[str, str] | None = None
    ) -> requests.Response:
        """Send a GET request to ``/api/<path>`` with optional query parameters."""
        return self._send("GET", path, headers, params)

    def _send(
        self,
        method: str,
        path: str,
        headers: Mapping[str, str],
        params: Mapping[str, str] | None = None,
    ) -> requests.Response:
        """Prepare and send one request with the session's verify and timeout settings."""
        prepared_request = self._session.prepare_request(
            requests.Request(
                method=method,
                url=f"{self._base_url}/api/{path}",
                headers=headers,
                params=params,
            )
        )
        # Watch out: we must provide the verify keyword to every individual request call!
        # Else it will be overwritten by the REQUESTS_CA_BUNDLE env variable
        settings = self._session.merge_environment_settings(
            url=prepared_request.url, proxies={}, stream=None, verify=self._verify, cert=None
        )
        return self._session.send(
            prepared_request,
            timeout=self._timeout,
            **settings,
        )


class PureStorageFlashArray:
    """Client for the REST API of a Pure Storage FlashArray.

    Expected call order: ``login()``, ``read_latest_api_version()``, then
    ``collect_section_data()`` once per section.
    """

    # Transport failures that are reported via the method-specific exception
    # instead of crashing the agent. ``Timeout`` is needed in addition to
    # ``ConnectionError`` because read timeouts are not ConnectionErrors.
    _TRANSPORT_ERRORS = (requests.exceptions.ConnectionError, requests.exceptions.Timeout)

    def __init__(self, server: str, cert_check: bool | str, timeout: int) -> None:
        """Create a client for ``server``.

        Args:
            server: Host name or IP address of the array.
            cert_check: ``False`` disables certificate verification, a string is the
                name expected in the server's certificate, ``True`` keeps default
                verification.
            timeout: Timeout for each HTTP request.
        """
        self._session = _PureStorageFlashArraySession(server, cert_check, timeout)
        # Session token; set by login(), empty until then.
        self._x_auth_token = ""

    def login(self, api_token: Secret[str]) -> None:
        """Exchange the API token for a session token.

        Raises:
            AuthError: The request failed, returned a non-200 status, or the
                response carries no ``x-auth-token`` header.
        """
        try:
            login_response = self._session.post(
                f"{_REST_VERSION}/login",
                {
                    "Content-Type": "application/json",
                    "User-Agent": USER_AGENT,
                    "api-token": api_token.reveal(),
                },
            )
        except self._TRANSPORT_ERRORS as e:
            _LOGGER.error("Login failed: %s", e)
            raise AuthError() from e

        if login_response.status_code != 200:
            _LOGGER.error(
                "Login failed: %s (%s)",
                login_response.reason,
                login_response.status_code,
            )
            raise AuthError()

        try:
            self._x_auth_token = login_response.headers["x-auth-token"]
        except KeyError:
            _LOGGER.error("Login failed: response contains no 'x-auth-token' header")
            raise AuthError() from None

    def read_latest_api_version(self) -> _RestVersion:
        """Return the newest offered REST API version with the major number of _REST_VERSION.

        Raises:
            APIVersionError: The request failed, returned a non-200 status, or no
                offered version has a matching major number.
        """
        try:
            api_version_response = self._session.get("api_version", {"User-Agent": USER_AGENT})
        except self._TRANSPORT_ERRORS as e:
            _LOGGER.error("Getting API version failed: %s", e)
            raise APIVersionError() from e

        if api_version_response.status_code != 200:
            _LOGGER.error(
                "Getting API version failed: %s (%s)",
                api_version_response.reason,
                api_version_response.status_code,
            )
            raise APIVersionError()

        compatible_versions = [
            version
            for raw_version in api_version_response.json()["version"]
            if (version := _RestVersion.from_raw(raw_version)).major == _REST_VERSION.major
        ]
        if not compatible_versions:
            # max() of an empty sequence would raise a bare ValueError here.
            _LOGGER.error(
                "Getting API version failed: no REST API version %s.x offered",
                _REST_VERSION.major,
            )
            raise APIVersionError()

        return max(compatible_versions)

    def collect_section_data(self, latest_version: _RestVersion, spec: _SectionSpec) -> Mapping:
        """Fetch the JSON payload of one section.

        Args:
            latest_version: REST API version used in the request path.
            spec: Section to collect; ``spec.path`` and ``spec.params`` define the request.

        Returns:
            The decoded JSON response body.

        Raises:
            SectionError: The request failed, returned a non-200 status, or the
                body is not valid JSON.
        """
        try:
            section_response = self._session.get(
                f"{latest_version}/{spec.path}",
                headers={
                    "Content-Type": "application/json",
                    "User-Agent": USER_AGENT,
                    "x-auth-token": self._x_auth_token,
                },
                params=spec.params,
            )
        except self._TRANSPORT_ERRORS as e:
            _LOGGER.error("Collecting '%s' failed: %s", spec.name, e)
            raise SectionError() from e

        if section_response.status_code != 200:
            _LOGGER.error(
                "Collecting '%s' failed: %s (%s)",
                spec.name,
                section_response.reason,
                section_response.status_code,
            )
            raise SectionError()

        try:
            return section_response.json()
        except ValueError as e:
            # requests' JSONDecodeError derives from ValueError.
            _LOGGER.error("Collecting '%s' failed: invalid JSON: %s", spec.name, e)
            raise SectionError() from e


def _filter_applicable_sections(
    latest_version: _RestVersion, sections: Sequence[_SectionSpec]
) -> Iterator[_SectionSpec]:
    """Yield, in their original order, the sections supported by ``latest_version``.

    A section whose ``min_version`` is newer than ``latest_version`` is logged as an
    error and not yielded.
    """
    for candidate in sections:
        if candidate.min_version <= latest_version:
            yield candidate
            continue

        _LOGGER.error(
            "Collecting '%s' failed: '%s' > '%s'",
            candidate.name,
            candidate.min_version,
            latest_version,
        )


def agent_pure_storage_fa(args: argparse.Namespace) -> int:
    """Run the special agent and write all collected sections to stdout.

    Logs in, determines the newest usable REST API version, and emits one
    ``<<<pure_storage_fa_<name>:sep(0)>>>`` section with a JSON payload per
    applicable section spec. Sections that fail are skipped.

    Args:
        args: Parsed command line arguments (see ``parse_arguments``).

    Returns:
        0 on success; 1 if the login or the API version lookup fails, or if a
        section fails while ``--debug`` is set.
    """
    pure_storage_fa = PureStorageFlashArray(
        args.server,
        args.cert_server_name or not args.no_cert_check,
        int(args.timeout),
    )

    try:
        pure_storage_fa.login(resolve_secret_option(args, API_TOKEN_OPTION))
    except AuthError:
        return 1

    try:
        latest_version = pure_storage_fa.read_latest_api_version()
    except APIVersionError:
        return 1

    for spec in _filter_applicable_sections(latest_version, _SECTIONS):
        try:
            data = pure_storage_fa.collect_section_data(latest_version, spec)
        except SectionError:
            if args.debug:
                return 1
            # Skip the failed section. Falling through would write ``data`` while it
            # is either unbound (first section) or still holds the previous section's
            # payload, which would then appear under this section's header.
            continue

        sys.stdout.write(f"<<<pure_storage_fa_{spec.name}:sep(0)>>>\n{json.dumps(data)}\n")

    return 0

@report_agent_crashes(AGENT, __version__)
def main() -> int:
    """Command line entry point.

    Parses ``sys.argv`` without the program name and returns the agent's exit code.
    """
    cli_arguments = sys.argv[1:]
    parsed = parse_arguments(cli_arguments)
    return agent_pure_storage_fa(parsed)


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())