# Copyright (C) 2025 Siemens
#
# SPDX-License-Identifier: MIT

from collections import namedtuple
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from debian.deb822 import Dsc, Deb822, Sources, Packages
from debian.debian_support import Version
import logging
from pathlib import Path

from ..util.compression import (
    CompressionToolMissing,
    find_compressed_file_variants,
    stream_compressed_file,
)
from ..dpkg.package import BinaryPackage, SourcePackage
from .. import HAS_PYTHON_APT

logger = logging.getLogger(__name__)


@dataclass
class ExtendedStates:
    """
    The apt extended states encode information if a package is manually
    installed or installed via a dependency only.
    """

    PackageFilter = namedtuple("BinaryPackage", "name arch")

    auto_installed: set[tuple[str, str]]

    def is_manual(self, name: str, arch: str) -> bool:
        """True if package is explicitly installed"""
        return (name, arch) not in self.auto_installed

    @classmethod
    def from_file(
        cls, file: str | Path, filter_fn: Callable[[PackageFilter], bool] | None = None
    ) -> "ExtendedStates":
        """Factory to create instance from the apt extended states file"""
        auto_installed = set()
        with open(Path(file)) as f:
            for s in Deb822.iter_paragraphs(f, use_apt_pkg=HAS_PYTHON_APT):
                name = s.get("Package")
                arch = s.get("Architecture")
                if s.get("Auto-Installed") != "1":
                    continue
                if (filter_fn is None) or (filter_fn(cls.PackageFilter(name, arch))):
                    auto_installed.add((name, arch))

        return cls(auto_installed=auto_installed)


@dataclass
class Repository:
    """
    Represents a debian repository as cached by apt.

    Instances are usually created by ``from_apt_cache`` from the ``*Release``
    files of an apt lists directory. The paths of the ``Sources`` and
    ``Packages`` index files are derived from the path of the release file
    (see ``repo_base``).
    """

    release_file: Path
    origin: str | None
    codename: str | None
    architectures: list[str]
    components: list[str] | None = None
    version: Version | None = None
    description: str | None = None

    # Views of a package that are handed to user supplied filter callbacks.
    # They carry no annotation, so they are class attributes and not fields.
    BinaryPackageFilter = namedtuple("BinaryPackage", "name arch version")
    SourcePackageFilter = namedtuple("SourcePackage", "name version")

    @classmethod
    def from_apt_cache(cls, lists_dir: str | Path) -> Iterable["Repository"]:
        """
        Create repositories from apt lists directory.

        Every directory entry whose name ends with ``Release`` is parsed as a
        deb822 file. Entries without any ``Architectures`` are logged as an
        error and skipped.
        """
        for entry in Path(lists_dir).iterdir():
            if not entry.name.endswith("Release"):
                continue
            with open(entry) as f:
                release = Deb822(f)
            logger.info(f"Found apt lists cache repository: {entry}")
            architectures = release.get("Architectures", "").split()
            if not architectures:
                logger.error(f"Repository does not specify 'Architectures', ignoring: {entry}")
                continue
            version = release.get("Version")
            components = release.get("Components")
            # use cls (not the hard-coded class) so that subclasses get
            # instances of their own type
            yield cls(
                release_file=entry,
                origin=release.get("Origin"),
                codename=release.get("Codename"),
                version=Version(version) if version else None,
                architectures=architectures,
                components=components.split() if components else None,
                description=release.get("Description"),
            )

    @classmethod
    def _safe_srcpkg_filter(
        cls, p: Packages, filter_fn: Callable[[SourcePackageFilter], bool] | None
    ) -> bool:
        """
        Decide if a source package paragraph shall be processed.

        Paragraphs without ``Package`` or ``Version`` are rejected with a
        warning. Without a ``filter_fn`` all other paragraphs are accepted,
        otherwise the result of the callback is returned.
        """
        try:
            package = p["Package"]
            version = p["Version"]
        except KeyError as e:
            logger.warning(f"skipping invalid source package: {e}")
            return False
        if filter_fn is None:
            return True
        return filter_fn(cls.SourcePackageFilter(package, version))

    @classmethod
    def _safe_binpkg_filter(
        cls, p: Packages, filter_fn: Callable[[BinaryPackageFilter], bool] | None
    ) -> bool:
        """
        Decide if a binary package paragraph shall be processed.

        Paragraphs without ``Package``, ``Architecture`` or ``Version`` are
        rejected with a warning. Without a ``filter_fn`` all other paragraphs
        are accepted, otherwise the result of the callback is returned.
        """
        try:
            package = p["Package"]
            arch = p["Architecture"]
            version = p["Version"]
        except KeyError as e:
            logger.warning(f"skipping invalid binary package: {e}")
            return False
        if filter_fn is None:
            return True
        return filter_fn(cls.BinaryPackageFilter(package, arch, version))

    @classmethod
    def _make_srcpkgs(
        cls,
        sources: Iterable[Packages],
        filter_fn: Callable[[SourcePackageFilter], bool] | None = None,
    ) -> Iterable[SourcePackage]:
        """
        Lazily convert the accepted paragraphs into source packages.

        Paragraphs whose conversion raises a ``KeyError`` are logged and
        skipped.
        """
        for source in sources:
            if not cls._safe_srcpkg_filter(source, filter_fn):
                continue
            # keep the try body minimal: only the conversion is guarded
            try:
                package = SourcePackage.from_deb822(Dsc(source))
            except KeyError as e:
                logger.error("control file in is not valid deb822, skip entry")
                logger.debug(e)
                continue
            yield package

    @classmethod
    def _make_binpkgs(
        cls,
        packages: Iterable[Packages],
        filter_fn: Callable[[BinaryPackageFilter], bool] | None = None,
    ) -> Iterable[BinaryPackage]:
        """
        Lazily convert the accepted paragraphs into binary packages.

        Paragraphs whose conversion raises a ``KeyError`` are logged and
        skipped.
        """
        for pkg in packages:
            if not cls._safe_binpkg_filter(pkg, filter_fn):
                continue
            # keep the try body minimal: only the conversion is guarded
            try:
                package = BinaryPackage.from_deb822(pkg)
            except KeyError as e:
                logger.error("control file in is not valid deb822, skip entry")
                logger.debug(e)
                continue
            yield package

    @classmethod
    def _parse_sources(
        cls, sources_file: str, srcpkg_filter: Callable[[SourcePackageFilter], bool] | None = None
    ) -> Iterable["SourcePackage"]:
        """
        Yield the source packages of a single ``Sources`` index.

        The plain file is used if it exists, otherwise the first compressed
        variant found for that path. A missing index or compression tool is
        only logged, in which case nothing (more) is yielded.
        """
        sources_path = Path(sources_file)
        # The whole iteration stays inside the try block, as the content is
        # consumed lazily while the packages are yielded.
        try:
            if sources_path.exists():
                with open(sources_path) as f:
                    logger.debug(f"Parsing apt cache source packages: {sources_file}")
                    sources_raw = Packages.iter_paragraphs(f, use_apt_pkg=HAS_PYTHON_APT)
                    yield from cls._make_srcpkgs(sources_raw, srcpkg_filter)
            else:
                # raises IndexError if there is no compressed variant
                compressed_variant = find_compressed_file_variants(sources_path)[0]
                content = stream_compressed_file(compressed_variant)
                logger.debug(f"Parsing apt cache source packages: {sources_file}")
                # TODO: in python-debian >= 1.0.0 it is possible to directly
                # pass the filename of a compressed file when using apt_pkg
                sources_raw = Packages.iter_paragraphs(content, use_apt_pkg=False)
                yield from cls._make_srcpkgs(sources_raw, srcpkg_filter)
        except CompressionToolMissing as e:
            logger.warning(f'{e}: skipping path "{compressed_variant}"')
        except (FileNotFoundError, IndexError, RuntimeError):
            logger.debug(f"Missing apt cache sources: {sources_file}")

    @classmethod
    def _parse_packages(
        cls, packages_file: str, binpkg_filter: Callable[[BinaryPackageFilter], bool] | None = None
    ) -> Iterable[BinaryPackage]:
        """
        Yield the binary packages of a single ``Packages`` index.

        The plain file is used if it exists, otherwise the first compressed
        variant found for that path. A missing index or compression tool is
        only logged, in which case nothing (more) is yielded.
        """
        packages_path = Path(packages_file)
        # The whole iteration stays inside the try block, as the content is
        # consumed lazily while the packages are yielded.
        try:
            if packages_path.exists():
                with open(packages_path) as f:
                    packages_raw = Packages.iter_paragraphs(f, use_apt_pkg=HAS_PYTHON_APT)
                    logger.debug(f"Parsing apt cache binary packages: {packages_file}")
                    yield from cls._make_binpkgs(packages_raw, binpkg_filter)
            else:
                # raises IndexError if there is no compressed variant
                compressed_variant = find_compressed_file_variants(packages_path)[0]
                content = stream_compressed_file(compressed_variant)
                # TODO: in python-debian >= 1.0.0 it is possible to directly
                # pass the filename of a compressed file when using apt_pkg
                packages_raw = Packages.iter_paragraphs(content, use_apt_pkg=False)
                logger.debug(f"Parsing apt cache binary packages: {packages_file}")
                yield from cls._make_binpkgs(packages_raw, binpkg_filter)
        except CompressionToolMissing as e:
            logger.warning(f'{e}: skipping path "{compressed_variant}"')
        except (FileNotFoundError, IndexError, RuntimeError):
            logger.debug(f"Missing apt cache packages: {packages_file}")

    @property
    def repo_base(self) -> str:
        """
        Path of the release file up to (excluding) its last underscore.

        This is the common prefix of the index file paths built in
        ``sources`` and ``binpackages``. It is empty if the path contains no
        underscore.
        """
        return str(self.release_file).rpartition("_")[0]

    def sources(
        self, filter_fn: Callable[[SourcePackageFilter], bool] | None = None
    ) -> Iterable[SourcePackage]:
        """
        Get all source packages from this repository.

        With components, one ``Sources`` index per component is parsed,
        otherwise a single index directly below the repository base. If a
        ``filter_fn`` is given, only packages it accepts are yielded.
        """
        if self.components:
            for component in self.components:
                sources_file = "_".join([self.repo_base, component, "source", "Sources"])
                yield from self._parse_sources(sources_file, filter_fn)
        else:
            sources_file = "_".join([self.repo_base, "source", "Sources"])
            # This function is a generator, so the packages have to be
            # delegated with "yield from". A "return <iterable>" would end the
            # iteration without yielding a single package.
            yield from self._parse_sources(sources_file, filter_fn)

    def binpackages(
        self,
        filter_fn: Callable[[BinaryPackageFilter], bool] | None = None,
        ext_states: ExtendedStates = ExtendedStates(set()),
    ) -> Iterable[BinaryPackage]:
        """
        Get all binary packages from this repository.

        One ``Packages`` index per architecture (and per component, if the
        repository has components) is parsed. If a ``filter_fn`` is given,
        only packages it accepts are yielded. The ``manually_installed``
        attribute of each package is set according to ``ext_states``.

        The default ``ext_states`` instance is shared between calls, but it is
        only read here. As it is empty, all packages count as manually
        installed.
        """
        if self.components:
            index_dirs = [
                [component, f"binary-{arch}"]
                for component in self.components
                for arch in self.architectures
            ]
        else:
            index_dirs = [[f"binary-{arch}"] for arch in self.architectures]

        for parts in index_dirs:
            packages_file = "_".join([self.repo_base, *parts, "Packages"])
            for p in self._parse_packages(packages_file, filter_fn):
                p.manually_installed = ext_states.is_manual(p.name, p.architecture)
                yield p
