#!/usr/bin/env python3

"""Janky diffing."""

import argparse
import dataclasses
from dataclasses import dataclass
from functools import cached_property
import logging
import re
import subprocess
from typing import Collection, Mapping, Optional, Sequence, Tuple
import difflib

_logger = logging.getLogger(__name__)

# Root directories scanned on both hosts (default for --paths).
RELEVANT_ROOTS = (
    '/etc',
    '/opt',
    '/usr',
)
# Paths left out of the comparison (default for --exclude). Each entry is
# handed to find's `-path ... -prune`, so it must match the whole path.
EXCLUDED_PATHS = (
    '/usr/lib/modules',
    '/usr/lib/firmware',
    '/usr/include',
    '/usr/share/man',
    '/opt/keycloak',
    '/etc/letsencrypt',
    '/etc/pacman.d/mirrorlist',
    '/etc/pacman.d/mirrorlist.bak',
    '/etc/pacman.d/gnupg',
    '/etc/motd',  # was the truncated '/etc/mot', which matches nothing
)


@dataclass(frozen=True)
class SshConfig:
    """How to reach one host with the ``ssh`` client."""

    hostname: str
    username: str = 'root'
    private_key: Optional[str] = None  # path passed to `ssh -i`, if any

    @cached_property
    def ssh_command(self) -> Tuple[str, ...]:
        """argv prefix that runs a command on this host via ssh."""
        destination = f'{self.username}@{self.hostname}'
        identity = () if self.private_key is None else ('-i', self.private_key)
        return ('ssh', *identity, destination)


@dataclass(frozen=True)
class Stat:
    """One line of ``stat -c '%N %F %a %U %G'`` output, as raw text fields."""

    name: str            # path, without stat's surrounding quotes
    link: Optional[str]  # symlink target; None when the line has no " -> '...'" part
    file_type: str       # %F, e.g. 'regular file' or 'symbolic link'
    mode: str            # %a, octal permission bits as printed by stat
    user: str            # %U
    group: str           # %G

    # %N prints 'name' (plus " -> 'target'" for symlinks). Both quotes
    # around the target must be consumed, or `link` keeps a trailing "'".
    # %a is printed without zero padding, so e.g. mode 0000 shows as "0":
    # accept 1-4 octal digits rather than requiring at least three.
    _re = re.compile(r"'(?P<name>.+?)'"
                     r"( -> '(?P<link>.+?)')?"
                     r" (?P<file_type>.+)"
                     r" (?P<mode>[0-7]{1,4})"
                     r" (?P<user>[^ ]+)"
                     r" (?P<group>[^ ]+)")

    @classmethod
    def from_line(cls, line: bytes) -> 'Stat':
        """Parse one raw output line from stat.

        Raises:
            ValueError: if the line does not have the expected layout.
        """
        match = cls._re.match(line.decode())
        if not match:
            raise ValueError(f'Invalid line: {line}')
        return cls(**match.groupdict())

    def replace(self, **overrides) -> 'Stat':
        """Return a copy of this Stat with the given fields overridden."""
        return dataclasses.replace(self, **overrides)


class DistDiff:
    """Gathers filesystem metadata and file contents from one host over ssh.

    Every query runs as a remote command built from ``self.ssh_config``.
    """

    # Characters the remote shell would otherwise interpret. ssh passes its
    # command arguments to the remote shell as a single space-joined string,
    # so each of these gets a backslash prefix to be taken literally.
    # Escaping is hand-rolled on purpose (per the original note, shlex
    # quoting "breaks stuff").
    _SHELL_SPECIAL = frozenset(' \t;()&|<>$`\\"\'*?[]{}!#~')

    def __init__(self,
                 ssh_config: SshConfig,
                 paths: Sequence[str] = RELEVANT_ROOTS,
                 excluded_paths: Sequence[str] = EXCLUDED_PATHS):
        """
        Args:
            ssh_config: how to reach the host.
            paths: roots handed to find(1).
            excluded_paths: paths pruned from every find(1) listing.
        """
        self.ssh_config = ssh_config
        self.paths = paths
        self.excluded_paths = excluded_paths

    @classmethod
    def _shell_escape(cls, arg: str) -> str:
        """Backslash-escape every shell metacharacter in *arg*."""
        return ''.join(f'\\{ch}' if ch in cls._SHELL_SPECIAL else ch
                       for ch in arg)

    def run(self, *command: str, **kwargs) -> bytes:
        """Run *command* on the remote host and return its stdout.

        Extra keyword arguments (e.g. ``input=``) go to ``subprocess.run``.

        Raises:
            RuntimeError: if the command wrote anything to stderr; each
                stderr line is logged at ERROR level first.
        """
        escaped = tuple(self._shell_escape(arg) for arg in command)
        whole_command = self.ssh_config.ssh_command + escaped
        _logger.debug('Executing %s', ' '.join(whole_command))
        p = subprocess.run(whole_command, capture_output=True, **kwargs)
        if p.stderr:
            # errors='replace' so undecodable stderr still reaches the raise
            for line in p.stderr.decode(errors='replace').splitlines():
                _logger.error(line)
            raise RuntimeError('stderr populated')
        return p.stdout

    @staticmethod
    def _any_of(test: str, values: Sequence[str]) -> Tuple[str, ...]:
        """Build the find(1) expression ``( TEST v1 -o TEST v2 ... )``."""
        terms = []
        for value in values:
            if terms:
                terms.append('-o')
            terms += [test, value]
        return ('(', *terms, ')')

    def find(self, *types: str) -> bytes:
        """List entries under ``self.paths`` as NUL-terminated paths.

        Args:
            *types: find(1) ``-type`` letters (e.g. ``'f'``); entries of any
                of these types are listed. With none, every entry is listed.

        Entries matching ``self.excluded_paths`` are pruned: neither they
        nor, for directories, anything beneath them is listed.
        """
        command = ('find', *self.paths)

        # Prune filter. It must come before the type filter: if the type
        # test came first, an excluded *directory* would fail e.g.
        # `-type f`, its `-prune` would never be evaluated, and find would
        # descend into it and list everything inside.
        if self.excluded_paths:
            command += self._any_of('-path', self.excluded_paths) + ('-prune', '-o')

        # Type filter
        if types:
            command += self._any_of('-type', types)

        return self.run(*command, '-print0')

    def stat(self) -> Sequence[Stat]:
        """Return one Stat per entry listed by ``find()``."""
        # `-r`: don't run the command at all when find printed nothing
        # (otherwise it runs once without operands and errors on stderr).
        stats = self.run('xargs', '-0', '-r',
                         'stat', '-c', '%N %F %a %U %G',
                         input=self.find())
        return [Stat.from_line(line) for line in stats.splitlines()]

    def md5(self) -> Mapping[str, str]:
        """Return ``{path: md5 hex digest}`` for every regular file."""
        # TODO: Investigate how to make this faster with GNU parallel
        md5s = self.run('xargs', '-0', '-r',
                        'md5sum', '-z',
                        input=self.find('f'))
        md5s_d = {}
        for record in md5s.decode().split('\x00'):
            if not record:
                continue
            # A record is "<digest> <mode char><path>" (mode char ' ' for
            # text, '*' for binary). Split on the first space and drop one
            # char: str.split() would also strip a path's leading spaces.
            digest, _, rest = record.partition(' ')
            md5s_d[rest[1:]] = digest
        return md5s_d

    def file_type(self) -> Mapping[str, Collection[str]]:
        """Return ``{path: set of file(1) description tags}`` for regular files.

        The description is split on ', ', e.g. 'Python script, ASCII text
        executable' -> {'Python script', 'ASCII text executable'}.
        """
        # TODO: Investigate how to make this faster with GNU parallel
        # `-0` is passed twice so that the output is NUL-separated
        # name/description pairs; the pairing below relies on that layout.
        types = self.run('xargs', '-0', '-r',
                         'file', '-0', '-0',
                         input=self.find('f'))

        types_list = types.decode().split('\x00')
        types_d = {}
        for filename, tags in zip(types_list[::2], types_list[1::2]):
            types_d[filename] = set(tags.split(', '))
        return types_d

    def get_file(self, file: str) -> str:
        """Return the remote file's contents as text.

        Undecodable bytes become U+FFFD, so a non-UTF-8 text file (file(1)
        also reports e.g. 'ISO-8859 text') can still be diffed instead of
        raising UnicodeDecodeError.
        """
        return self.run('cat', file).decode(errors='replace')


def compare(host1, host2, paths, exclude, mode_ignore):
    """Print how the filesystems of *host1* and *host2* differ.

    Both hosts are inspected over ssh with ``SshConfig``'s defaults.
    Output sections, in order:

    * entries present on only one host;
    * common entries whose stat line differs (for paths starting with a
      prefix in *mode_ignore*, the permission mode is not compared);
    * regular files whose md5 differs: a unified diff when file(1) tags
      both copies as text, otherwise the two digests.

    Args:
        host1: ssh host shown as the "-" side.
        host2: ssh host shown as the "+" side.
        paths: roots to scan.
        exclude: paths pruned from the scan.
        mode_ignore: path prefixes whose mode differences are ignored.
    """
    dd1 = DistDiff(SshConfig(host1), paths, exclude)
    dd2 = DistDiff(SshConfig(host2), paths, exclude)

    dd1_stat = {s.name: s for s in dd1.stat()}
    dd2_stat = {s.name: s for s in dd2.stat()}

    dd1_md5 = dd1.md5()
    dd2_md5 = dd2.md5()

    dd1_type = dd1.file_type()
    dd2_type = dd2.file_type()

    def is_text(tags) -> bool:
        # file(1) tags such as 'ASCII text' or 'ASCII text executable'
        return any(tag.endswith(('text', 'text executable')) for tag in tags)

    dd1_only = dd1_stat.keys() - dd2_stat.keys()
    dd2_only = dd2_stat.keys() - dd1_stat.keys()

    if dd1_only:
        print(f'Files only on {host1}:')
        for file in sorted(dd1_only):
            print(f'  {dd1_stat[file]}')
        print('')

    if dd2_only:
        print(f'Files only on {host2}:')
        for file in sorted(dd2_only):
            print(f'  {dd2_stat[file]}')
        print('')

    # Built once; str.startswith accepts a tuple of prefixes.
    mode_ignored_prefixes = tuple(mode_ignore or ())

    print(f'Stat differences between -{host1} and +{host2}')
    for file in sorted(dd1_stat.keys() & dd2_stat.keys()):
        left, right = dd1_stat[file], dd2_stat[file]
        if file.startswith(mode_ignored_prefixes):
            differs = left.replace(mode=None) != right.replace(mode=None)
        else:
            differs = left != right
        if differs:
            print(f'  -{left}')
            print(f'  +{right}')

    print(f'Content diffs between {host1} and {host2}')
    # Sorted so the report order is stable, like the stat section above.
    for file in sorted(dd1_md5.keys() & dd2_md5.keys()):
        if dd1_md5[file] == dd2_md5[file]:
            continue
        # .get: a file can be missing from one listing (e.g. it changed
        # between the remote runs); treat it as non-text rather than crash.
        if is_text(dd1_type.get(file, ())) and is_text(dd2_type.get(file, ())):
            left_lines = dd1.get_file(file).splitlines()
            right_lines = dd2.get_file(file).splitlines()
            # lineterm='': the inputs carry no newlines, so the header and
            # hunk lines must not either, or print() emits blank lines.
            for diff in difflib.unified_diff(left_lines, right_lines,
                                             fromfile=f'{host1}:{file}',
                                             tofile=f'{host2}:{file}',
                                             lineterm=''):
                print(diff)
        else:
            print(file)
            print(f'  -md5:{dd1_md5[file]}')
            print(f'  +md5:{dd2_md5[file]}')


if __name__ == '__main__':
    # Command-line entry point: diff two hosts given by name.
    cli = argparse.ArgumentParser('dist_diff')
    cli.add_argument('host1')
    cli.add_argument('host2')
    for flag, default in (('--paths', RELEVANT_ROOTS),
                          ('--exclude', EXCLUDED_PATHS),
                          ('--mode-ignore', ['/boot'])):
        cli.add_argument(flag, nargs='+', default=default)
    options = cli.parse_args()
    logging.basicConfig(level=logging.INFO)
    compare(options.host1, options.host2,
            options.paths, options.exclude, options.mode_ignore)