#!/usr/bin/env python3

# +------------------------------------------------------------+
# |                                                            |
# |             | |             | |            | |             |
# |          ___| |__   ___  ___| | ___ __ ___ | | __          |
# |         / __| '_ \ / _ \/ __| |/ / '_ ` _ \| |/ /          |
# |        | (__| | | |  __/ (__|   <| | | | | |   <           |
# |         \___|_| |_|\___|\___|_|\_\_| |_| |_|_|\_\          |
# |                                   custom code by SVA       |
# |                                                            |
# +------------------------------------------------------------+
#
#   This program is distributed in the hope that it will be useful,
#   but WITHOUT ANY WARRANTY; without even the implied warranty of
#   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#   GNU General Public License for more details.
#
#   Copyright (C) 2026  SVA System Vertrieb Alexander GmbH
#                       by sebastian.haeger@sva.de
#
#   Last modified: 17.04.2026

import argparse
import sys
import socket
import json
import getpass
from enum import Enum
from pprint import pprint
import re
from typing import Dict

class WhichIsRegex(Enum):
    """Which part(s) of a ``key:value`` source label are regular expressions.

    A part is marked as a regex by a leading ``~``.
    """

    # Neither key nor value is a regex: both are compared literally.
    NO = -1
    # Only the key is a regex.
    KEY = 0
    # Only the value is a regex.
    VALUE = 1
    # Both key and value are regexes.
    KEYANDVALUE = 2


class AUX_Host_Labels:
    def __init__(self, labels):
        self.labels = [tuple(x) for x in json.loads(labels)]
        self.site = getpass.getuser()
        self.address = f"/omd/sites/{self.site}/tmp/run/live"
        self.piggyback_data: dict[str, dict[str, str]] = {}

    def format_label_for_lq(self, label: str, isregex: WhichIsRegex):
        key, value = label.split(":")
        if isregex is WhichIsRegex.NO:
            return (key.strip(), value.strip())
        if isregex is not WhichIsRegex.NO:
            key = key.removeprefix("~").strip()
            value = value.removeprefix("~").strip()
            return (key, value)

    def get_key_value_is_regex(self, label: str) -> WhichIsRegex:
        key, value = label.split(":")
        which_is_regex = WhichIsRegex.NO
        if key != key.removeprefix("~"):
            which_is_regex = WhichIsRegex.KEY
        if value != value.removeprefix("~"):
            which_is_regex = WhichIsRegex.VALUE
        if value != value.removeprefix("~") and key != key.removeprefix("~"):
            which_is_regex = WhichIsRegex.KEYANDVALUE
        return which_is_regex

    def create_piggyback(self, host, new_label):
        key, value = new_label.split(":", 1)

        if host not in self.piggyback_data:
            self.piggyback_data[host] = {}

        self.piggyback_data[host][key] = value

    def print_piggyback(self) -> None:
        for host, labels in self.piggyback_data.items():
            sys.stdout.write(f"<<<<{host}>>>>\n")
            sys.stdout.write("<<<labels:sep(0)>>>\n")
            sys.stdout.write("{")
            for i, (key, value) in enumerate(labels.items()):
                if i > 0:
                    sys.stdout.write(",")
                sys.stdout.write(f'"{key}":"{value}"')
            sys.stdout.write("}\n")
            sys.stdout.write("<<<<>>>>\n")

    def execute_lq_query(self, query, address):
        family = socket.AF_INET if type(address) is tuple else socket.AF_UNIX
        sock = socket.socket(family, socket.SOCK_STREAM)
        sock.connect(address)

        sock.sendall(query.encode("utf-8") + b"\n")
        sock.shutdown(socket.SHUT_WR)

        chunks = []
        while True:
            data = sock.recv(4096).decode("utf-8")
            if data == "":
                break
            chunks.append(data)
        sock.close()
        reply = "".join(chunks)
        return json.loads(reply)

    def filter_hosts(self, data, *, key=None, value=None):
        if key is None and value is None:
            raise ValueError("Mindestens key_pattern oder value_pattern angeben")

        result = {}
        for host, items in data.items():
            matches = []
            for d in items:
                if key is not None and key not in d:
                    continue
                if value is not None and value not in d.values():
                    continue
                matches.append(d)

            if matches:
                result[host] = matches
        return result

    def compile_regex(self, regex) -> None:
        try:
            re.compile(regex)
        except re.error as e:
            raise ValueError(f"Ungültiger Regex: {e}")

    def regex_filter_hosts(
        self,
        data: Dict[str, list[dict]],
        *,
        key_pattern: str | None = None,
        value_pattern: str | None = None,
        flags=re.IGNORECASE,
    ) -> Dict[str, list[tuple[dict, re.Match]]]:

        if key_pattern is None and value_pattern is None:
            raise ValueError("Mindestens key_pattern oder value_pattern angeben")

        key_re = re.compile(key_pattern, flags) if key_pattern else None
        value_re = re.compile(value_pattern, flags) if value_pattern else None

        result: Dict[str, list[tuple[dict, re.Match]]] = {}

        for host, items in data.items():
            matches: list[tuple[dict, re.Match]] = []

            for d in items:
                for k, v in d.items():
                    key_match = key_re.search(k) if key_re else None
                    value_match = value_re.search(str(v)) if value_re else None

                    if key_re and not key_match:
                        continue
                    if value_re and not value_match:
                        continue

                    match = value_match or key_match
                    if match:
                        matches.append((d, match))
                    break

            if matches:
                result[host] = matches

        return result

    def apply_match_groups(self, target: str, match: re.Match) -> str:
        if match is None:
            return target
        result = target

        for k, v in match.groupdict().items():
            result = result.replace(f"{{{k}}}", v)

        for i, g in enumerate(match.groups(), start=1):
            result = result.replace(f"\\{i}", g)

        return result

    def main(self, debug=False):
        cmk_hosts_with_labels = {}

        for arg_label in self.labels:
            source, target = arg_label
            which_is_regex = self.get_key_value_is_regex(source)
            query = "GET hosts\nColumns: name labels\nOutputFormat: json\n"
            hostlabel_response = self.execute_lq_query(query, self.address)
            for entry in hostlabel_response:
                cmk_hosts_with_labels[entry[0]] = []
                for label_key in entry[1]:
                    label_dict = {}
                    label_dict[label_key] = entry[1][label_key]
                    cmk_hosts_with_labels[entry[0]].append(label_dict)
            lq_source_label_key, lq_source_label_value = self.format_label_for_lq(
                source, which_is_regex
            )
            resulted_hosts_filter = None

            if which_is_regex is WhichIsRegex.NO:
                resulted_hosts_filter = self.filter_hosts(
                    cmk_hosts_with_labels,
                    key=lq_source_label_key,
                    value=lq_source_label_value,
                )
            elif which_is_regex is WhichIsRegex.KEY:
                filtered_data = self.filter_hosts(
                    cmk_hosts_with_labels, value=lq_source_label_value
                )
                self.compile_regex(lq_source_label_key)
                if lq_source_label_key != ".*":
                    resulted_hosts_filter = self.regex_filter_hosts(
                        filtered_data, key_pattern=lq_source_label_key
                    )
                else:
                    resulted_hosts_filter = filtered_data
            elif which_is_regex is WhichIsRegex.VALUE:
                filtered_data = self.filter_hosts(
                    cmk_hosts_with_labels, key=lq_source_label_key
                )
                self.compile_regex(lq_source_label_value)
                if lq_source_label_value != ".*":
                    resulted_hosts_filter = self.regex_filter_hosts(
                        filtered_data, value_pattern=lq_source_label_value
                    )
                else:
                    resulted_hosts_filter = filtered_data
            elif which_is_regex is WhichIsRegex.KEYANDVALUE:
                self.compile_regex(lq_source_label_key)
                self.compile_regex(lq_source_label_value)

                if lq_source_label_key != ".*" and lq_source_label_value != ".*":
                    resulted_hosts_filter = self.regex_filter_hosts(
                        cmk_hosts_with_labels,
                        key_pattern=lq_source_label_key,
                        value_pattern=lq_source_label_value,
                    )
                elif lq_source_label_key == ".*" and lq_source_label_value != ".*":
                    resulted_hosts_filter = self.regex_filter_hosts(
                        cmk_hosts_with_labels,
                        value_pattern=lq_source_label_value,
                    )
                elif lq_source_label_key != ".*" and lq_source_label_value == ".*":
                    resulted_hosts_filter = self.regex_filter_hosts(
                        cmk_hosts_with_labels,
                        key_pattern=lq_source_label_key,
                    )
                elif lq_source_label_key == ".*" and lq_source_label_value == ".*":
                    resulted_hosts_filter = cmk_hosts_with_labels

            if resulted_hosts_filter is not None and not debug:
                for host, matches in resulted_hosts_filter.items():
                    for item in matches:
                        if isinstance(item, tuple):
                            label_dict, match = item
                        else:
                            label_dict = item
                            match = None
                        resolved_target = self.apply_match_groups(target, match)
                        self.create_piggyback(host=host, new_label=resolved_target)

            else:
                debug_output = []

                if resulted_hosts_filter:
                    for host, matches in resulted_hosts_filter.items():
                        for item in matches:
                            if isinstance(item, tuple):
                                label_dict, match = item
                            else:
                                label_dict = item
                                match = None

                            source_key, source_value = next(iter(label_dict.items()))
                            resolved_target = self.apply_match_groups(target, match)

                            debug_output.append(
                                {
                                    "host": host,
                                    "source_label": f"{source_key}:{source_value}",
                                    "key_regex": lq_source_label_key
                                    if which_is_regex
                                    in (WhichIsRegex.KEY, WhichIsRegex.KEYANDVALUE)
                                    else None,
                                    "value_regex": lq_source_label_value
                                    if which_is_regex
                                    in (WhichIsRegex.VALUE, WhichIsRegex.KEYANDVALUE)
                                    else None,
                                    "target_label": resolved_target,
                                }
                            )

                pprint(debug_output)


def parse_arguments(argv):
    """Parse the special agent's command line.

    Args:
        argv: Argument list without the program name (e.g. ``sys.argv[1:]``);
            ``None`` makes argparse fall back to ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with ``labels`` (JSON string of
        ``[source, target]`` pairs) and ``debug`` (bool).
    """
    parser = argparse.ArgumentParser(description=__doc__)

    parser.add_argument("--labels", help="", required=True)
    parser.add_argument(
        "--debug",
        help="",
        action="store_true",
        required=False,
    )

    # Parse the list we were given instead of silently reading sys.argv.
    return parser.parse_args(argv)


def main(argv=None):
    """Run the special agent and print the piggyback label sections.

    Args:
        argv: Command-line arguments without the program name; defaults to
            ``sys.argv[1:]`` when ``None``.
    """
    # Honour an explicitly passed argv (previously ignored).
    args = parse_arguments(sys.argv[1:] if argv is None else argv)
    aux_host_labels = AUX_Host_Labels(args.labels)
    aux_host_labels.main(args.debug)
    aux_host_labels.print_piggyback()


if __name__ == "__main__":
    # SystemExit carries main()'s return value as the process exit status.
    raise SystemExit(main())
