#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# License: GNU General Public License v2
#
#
# Author: thl-cmk[at]outlook[dot]com
# URL   : https://thl-cmk.hopto.org
# Date  : 2024-04-21
# File  : active_checks_radius.py
#
# Active check to monitor RADIUS servers.
#
# Uses the pyrad library: https://github.com/pyradius/pyrad
#
# 2024-04-29: removed dictionary.freeradius
#

from argparse import (
    ArgumentDefaultsHelpFormatter,
    ArgumentParser,
    ArgumentTypeError,
    Namespace,
)
from os import environ
from socket import error as socket_error
from sys import (
    argv as sys_argv,
    exit as sys_exit,
    stdout as sys_stdout,
)
from time import time_ns
from typing import List, Optional, Sequence, Tuple

import cmk.utils.password_store

no_radius_lib = False
try:
    from pyrad.client import Client as radClient
    from pyrad.dictionary import Dictionary as radDictionary
    from pyrad.packet import AccessAccept, AccessReject, AccessRequest
    from pyrad.client import Timeout as pyTimeout
except ModuleNotFoundError:
    no_radius_lib = True


class Args(Namespace):
    """Typed view of the options produced by ``parse_arguments``."""
    host: str  # RADIUS server host/IP-address, surrounding blanks stripped
    auth_port: int  # RADIUS authentication port (default 1812)
    secret: str  # shared RADIUS secret
    timeout: int  # timeout handed to the RADIUS client
    username: str  # user name sent in the Access-Request
    password: str  # user password sent in the Access-Request
    num_resp_attributes: Optional[int]  # expected number of response attributes, None = not checked
    state_wrong_num_resp_attributes: int  # monitoring state (0-3) if the attribute count differs
    # one entry per --request-attribute option, each a (possibly empty) list of (name, value) pairs
    request_attribute: List[List[Tuple[str, str]]]
    max_response_time: Optional[Tuple[int, int]]  # (warn, crit) levels in ms, None = not checked
    expected_response: Optional[int]  # 2 (accept) or 3 (reject), None = not checked
    state_not_expected_response: int  # monitoring state (0-3) if the response code differs


# Plugin version, shown in the --help epilog.
VERSION = '0.1.1-20240428'

# Checkmk state markers appended to output messages, indexed by monitoring
# state: 0 -> OK, 1 -> WARN, 2 -> CRIT, 3 -> UNKNOWN.
cmk_state = dict(enumerate(('', '(!)', '(!!)', '(?)')))
# Readable names of the RADIUS response codes that --expected-response accepts
# (2 -> Accepted, 3 -> Rejected).
response_str = dict(zip((2, 3), ('accept', 'reject')))


def parse_arguments(argv: Sequence[str]) -> Args:
    """Parse the command line of the RADIUS active check.

    :param argv: command line arguments without the program name.
    :return: the parsed options as an ``Args`` instance; ``host`` is stripped
        of surrounding blanks.
    :raises SystemExit: raised by argparse for missing or invalid arguments
        (and for --help).
    """
    def _av_pair(s: str) -> Tuple[str, str]:
        """Split "attribute-name:attribute-value" at the first colon."""
        try:
            name, value = s.split(':', 1)
        except ValueError:
            raise ArgumentTypeError("AV-Pairs must be in the form of name:value") from None
        return name, value

    def _levels(s: str) -> Tuple[int, int]:
        """Convert "warn,crit" into a pair of integers."""
        try:
            warn, crit = s.split(',')
            return int(warn), int(crit)
        except ValueError:
            raise ArgumentTypeError("Levels must be in the form 'warn,crit' value") from None

    parser = ArgumentParser(
        description='This is a (very) basic active RADIUS check for Check_mk. Tests if a RADIUS server is responsive '
                    '(accept/reject/timeout). There is (limited) support to add AV-pairs to the RADIUS request.',
        formatter_class=ArgumentDefaultsHelpFormatter,
        epilog=f'(c) thl-cmk[at]outlook[dot]com, Version: {VERSION}, For more information see: https://thl-cmk.hopto.org'
    )
    #
    # required request parameters
    #
    parser.add_argument(
        '-H', '--host', required=True,
        help='Host/IP-Address of RADIUS server to query (required)',
    )
    parser.add_argument(
        '--secret', required=True,
        help='secret RADIUS key',
    )
    parser.add_argument(
        '--username', default='dummyuser',
        help='user name to test with',
    )
    parser.add_argument(
        '--password', default='dummypassword',
        help='user password to test with',
    )
    #
    # optional request parameters
    #
    parser.add_argument(
        '--auth-port', type=int, default=1812,
        help='RADIUS authentication port to use.',
    )
    parser.add_argument(
        '--timeout', type=int, default=1,
        help='RADIUS server timeout',
    )
    # nargs='*' together with action='append' yields a list with one (possibly
    # empty) list of (name, value) pairs per occurrence of the option
    parser.add_argument(
        '--request-attribute', nargs='*', type=_av_pair, action='append', default=[],
        help='add request attribute in the form of "attribute-name:attribute-value" '
             'ie: "Called-Station-Id:AA-BB-CC-DD-EE-FF". Repeat to add more attributes. '
             'For valid attributes see the dictionary file.',
    )
    #
    # response parameters
    #
    parser.add_argument(
        '--expected-response', type=int, choices=[2, 3],
        help=' 2 -> Accepted, 3 -> Rejected',
    )
    parser.add_argument(
        '--state-not-expected-response', type=int, choices=[0, 1, 2, 3], default=2,
        help='Monitoring state: 0 -> OK, 1 -> WARN, 2 -> CRIT, 3 -> UNKNOWN',
    )
    parser.add_argument(
        '--num-resp-attributes', type=int,
        help='Expected number of response attributes',
    )
    parser.add_argument(
        '--state-wrong-num-resp-attributes', type=int, choices=[0, 1, 2, 3], default=1,
        help='Monitoring state: 0 -> OK, 1 -> WARN, 2 -> CRIT, 3 -> UNKNOWN',
    )
    parser.add_argument(
        '--max-response-time', type=_levels,
        help='Upper levels for response time in ms in the format WARN,CRIT time. ie: 10,50'
    )

    # parse into an Args instance so the result really is the annotated type
    args = parser.parse_args(argv, namespace=Args())
    args.host = args.host.strip(' ')
    return args


def main(args=None):
    """Query a RADIUS server once and print the result in Checkmk plugin format.

    Sends an Access-Request for --username/--password plus any
    --request-attribute AV-pairs. It then evaluates the response code, the
    response time and, for Access-Accept, the number of response attributes.

    :param args: command line arguments without the program name; defaults to
        ``sys.argv[1:]``.
    :return: monitoring state (0 OK, 1 WARN, 2 CRIT, 3 UNKNOWN) to be used as
        exit code. Exits directly with 3 if the pyrad library is not installed.
    """
    if args is None:
        args = sys_argv[1:]  # without the path/plugin itself

    args = parse_arguments(args)

    if no_radius_lib:
        sys_stdout.write(
            'To use this check plugin you need to install the python pyrad lib in your CMK python environment.(?)\n'
        )
        sys_exit(3)

    omd_root = environ["OMD_ROOT"]
    info_text = []
    long_output = []
    perf_data = []
    status = 0

    rad_server = radClient(
        server=args.host,
        authport=args.auth_port,
        secret=args.secret.encode('utf-8'),
        # freeradius dictionaries are under /usr/share/freeradius/
        dict=radDictionary(f'{omd_root}/local/lib/python3/cmk_addons/plugins/check_radius/libexec/dictionary'),
        timeout=args.timeout,
    )

    rad_req = rad_server.CreateAuthPacket(
        code=AccessRequest,
        User_Name=args.username,
        NAS_Identifier=args.host,
    )
    rad_req["User-Password"] = rad_req.PwCrypt(args.password)

    # add optional request attributes: each --request-attribute option holds a
    # (possibly empty) list of (name, value) pairs, so add every pair of every option
    for av_pairs in args.request_attribute:
        for name, value in av_pairs:
            warning = None
            try:
                rad_req.AddAttribute(name, value)
            except TypeError:
                warning = (
                    f'WARNING: attribute value must be the real value not the name of the '
                    f'value: {value}{cmk_state[1]}'
                )
            except KeyError:
                # presumably raised for attribute names missing from the dictionary
                warning = f'WARNING: unknown request attribute: {name}{cmk_state[1]}'
            if warning:
                # report as part of the check output instead of writing an
                # unterminated line to stdout ahead of the summary line
                info_text.append(warning)
                long_output.append(warning)
                status = max(status, 1)

    before_request_time = time_ns()
    try:
        response = rad_server.SendPacket(rad_req)
    except pyTimeout as e:
        status = max(status, 2)
        message = f'Radius request timeout{cmk_state[2]}'
        info_text.append(message)
        long_output.append(f'{message}\n{e}')
    except socket_error as e:
        status = max(status, 2)
        message = f'Network error{cmk_state[2]}'
        info_text.append(message)
        long_output.append(f'{message}\n{e}')
    else:
        # first: calculate response time
        response_time = (time_ns() - before_request_time) / 1e9  # ns -> seconds

        #
        # second: check response code
        message = f'Response: access {response_str.get(response.code, f"unknown ({response.code})")}'
        if args.expected_response and response.code != args.expected_response:
            message += f' (expected: {response_str[args.expected_response]}{cmk_state[args.state_not_expected_response]})'
            status = max(status, args.state_not_expected_response)
        info_text.append(message)
        long_output.append(message)

        # third: check response time (levels are given in ms)
        message = f'Response time {response_time * 1000:.0f} ms'
        if args.max_response_time:
            warn, crit = args.max_response_time
            if response_time >= warn / 1000:
                state = 2 if response_time >= crit / 1000 else 1
                message += f' (WARN/CRIT at {warn}/{crit} ms{cmk_state[state]})'
                status = max(status, state)
            # the metric value is in seconds, so the levels must be in seconds too
            perf_data.append(f'radius_response_time={response_time};{warn / 1000};{crit / 1000};;')
        else:
            perf_data.append(f'radius_response_time={response_time}')
        info_text.append(message)
        long_output.append(message)

        if response.code == AccessAccept:
            #
            # fourth: check return attributes
            response_keys = response.keys()
            num_attributes = len(response_keys)
            if num_attributes:
                message = f'Number of attributes in response: {num_attributes}'
            else:
                message = 'No return attributes in response'

            # compare against None so that an expected count of 0 is checked as well
            if args.num_resp_attributes is not None and num_attributes != args.num_resp_attributes:
                message += f' (expected {args.num_resp_attributes}{cmk_state[args.state_wrong_num_resp_attributes]})'
                status = max(status, args.state_wrong_num_resp_attributes)
                info_text.append(message)

            long_output.append(message)

            if num_attributes:
                long_output.append('\nResponse attributes:')
                for key in response_keys:
                    long_output.append(f'{key}: {response.get(key)}')

    #
    # format output data: summary line, long output, then perf data after '|'
    # (metrics separated by spaces, per the Nagios plugin output convention)
    summary = ', '.join(info_text)
    details = '\n'.join(long_output)
    if perf_data:
        metrics = ' '.join(perf_data)
        sys_stdout.write(f'{summary}\n{details}|{metrics}\n')
    else:
        sys_stdout.write(f'{summary}\n{details}\n')
    return status


if __name__ == '__main__':
    # NOTE(review): presumably resolves Checkmk password-store references in
    # sys.argv before main() parses them — confirm against cmk.utils.password_store
    cmk.utils.password_store.replace_passwords()
    sys_exit(main())
