import inspect
import logging
import os
import time
from dataclasses import dataclass
from datetime import datetime
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from .analyse_measurements import AnalyseMeasurements
from .config import DATABASE, FILE_SUFFIXES, MEASUREMENTS_DIR
from .measurement import Measurement
from .measurement_manager import MeasurementManager
from .meta import Meta
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
[docs]@dataclass(kw_only=True)
class DBentry:
    "A XRD measurement database entry template."
    measurement_id: str
    """str: The ID of the measurement."""
    sample: str
    """str: The ID of the sample that was measured."""
    compound: str
    """str: The formula of the compound that was measured."""
    description: str
    """str: A description of the measurement."""
    comment: str
    """str: A comment about the measurement."""
    ht_mode: bool
    """bool: Whether the measurement was taken in high temperature mode."""
    processing_state: str = None
    """str: The processing state of the measurement, expression must be predefined."""
    xrd_datetime: str = None
    """str: The date and time at which the measurement was taken."""
    date_added: str
    """str: The date and time at when the measurement was added to the DB.""" 
@dataclass
class MeasurementDatabase:
    """XRD measurement database backed by a CSV file.

    On construction, the CSV file is loaded (if it exists) and the measurements
    directory is searched for meta data files of measurements that are not yet
    registered; any found are added and the CSV file is rewritten.

    Parameters
    ----------
    db_file : str
        Path to the database file (CSV).
    measurements_dir : str
        Path to the measurements directory.
    """

    db_file: str = DATABASE
    measurements_dir: str = MEASUREMENTS_DIR

    def __post_init__(self):
        # Used to load Measurement objects by measurement ID.
        self._m_manager = MeasurementManager(self.measurements_dir)
        self.entries = self._load_data()
        # Register measurements present on disk but missing from the CSV file.
        self.update_db_file()

    def _load_data(self) -> pd.DataFrame:
        """Load the DB CSV file indexed by ``measurement_id`` (sorted descending).

        Returns an empty DataFrame if the file does not exist yet.
        """
        if not os.path.isfile(self.db_file):
            return pd.DataFrame()
        # Read IDs as strings so numeric-looking IDs are not converted to numbers.
        return pd.read_csv(
            self.db_file, index_col="measurement_id", dtype={"measurement_id": str}
        ).sort_index(ascending=False)

    def _validate_new_id(self, new_id: str) -> None:
        """Ensure ``new_id`` is not registered in the DB yet.

        Raises
        ------
        ValueError
            If ``new_id`` is already present in the DB index.
        """
        if new_id in self.entries.index:
            raise ValueError(f"Measurement ID {new_id!r} already existing in DB.")

    def add_measurement(self, meta_obj: Meta, to_file: bool = True) -> None:
        """Add a new measurement to the database.

        Parameters
        ----------
        meta_obj :
            Meta object of the measurement to be added. Only its attributes
            whose names match a :class:`DBentry` field are stored.
        to_file :
            Flag indicating whether the updated DB is written to its CSV file.

        Raises
        ------
        ValueError
            If the measurement ID is already registered in the database.
        """
        meta_attrs = vars(meta_obj)
        # DBentry's field order defines the stored attributes and column order.
        row = {
            field: meta_attrs[field]
            for field in inspect.signature(DBentry).parameters
            if field in meta_attrs
        }
        row["date_added"] = datetime.now().isoformat()
        new_entry = pd.DataFrame([row]).set_index("measurement_id")
        if self.empty:
            # Replace rather than concatenate with an empty (column-less) frame.
            self.entries = new_entry
        else:
            self._validate_new_id(row["measurement_id"])
            self.entries = pd.concat([self.entries, new_entry])
        if to_file:
            self.to_file()

    def to_file(self) -> None:
        """Write the database content to its CSV file."""
        self.entries.to_csv(self.db_file)
        logger.debug(f"Database written to {self.db_file!r}")

    @property
    def empty(self) -> bool:
        """Return True if no measurements are registered."""
        # Count rows explicitly: ``DataFrame.empty`` would also be True for a
        # table that has rows but no columns.
        return len(self.entries) == 0

    def list_measurements(self) -> str:
        """Return a string with a table of the database entries."""
        if self.empty:
            return "No measurements registered in database"
        return str(self.entries)

    def list_measurements_for_db_key(
        self, column: str, value: str | list[str]
    ) -> list[str]:
        """Return measurement IDs whose ``column`` matches the given value(s).

        Can be used to filter the database for a specific data subset, e.g.
        measurements corresponding to a certain processing state. Results are
        ordered by sample and then by measurement ID.

        Parameters
        ----------
        column :
            Column in which the value(s) are looked up (e.g. `processing_state`).
        value :
            A single value (e.g. `refined`), or a list, tuple or set of values,
            any of which matches.

        Returns
        -------
        list[str]
            Matching measurement IDs; empty if the database is empty.
        """
        if self.empty:
            return []
        table = (
            self.entries.reset_index()
            .sort_values(["sample", "measurement_id"])
            .set_index("measurement_id")
        )
        if isinstance(value, (list, tuple, set, frozenset)):
            # A plain ``==`` with a list would compare element-wise by position.
            mask = table[column].isin(value)
        else:
            mask = table[column] == value
        return table.loc[mask].index.tolist()

    def _sorted_unique(self, column: str) -> list[str]:
        """Return the sorted unique non-null values of ``column`` ([] if absent)."""
        if column not in self.entries.columns:
            return []
        return list(self.entries[column].dropna().sort_values().unique())

    def list_compounds(self) -> list[str]:
        """Return the unique compounds registered in the DB (sorted)."""
        return self._sorted_unique("compound")

    def list_samples(self) -> list[str]:
        """Return the unique sample IDs registered in the DB (sorted)."""
        return self._sorted_unique("sample")

    def get_measurement_for_id(self, measurement_id: str) -> Measurement:
        """Return a measurement object for the provided measurement ID."""
        return self._m_manager.get_measurement(measurement_id)

    def get_measurement_for_sample(self, sample: str) -> Measurement:
        """Return a measurement object for the provided sample ID.

        If no DB entry lists ``sample``, it is tried as a measurement ID.

        Raises
        ------
        IndexError
            If no measurement can be found for ``sample``.
        ValueError
            If several measurements are registered for ``sample``.
        """
        measurement_ids = self.list_measurements_for_db_key("sample", sample)
        if len(measurement_ids) == 1:
            return self._m_manager.get_measurement(measurement_ids[0])
        if not measurement_ids:
            try:
                return self._m_manager.get_measurement(sample)
            except IndexError as err:
                raise IndexError(
                    f"No measurement registered for sample '{sample}'."
                ) from err
        raise ValueError(
            f"Multiple measurements registered for '{sample}'."
            "\n            Provide one of the measurement IDs instead:"
            f"\n            {measurement_ids}"
        )

    def get_measurements_for_compound(self, compound: str) -> AnalyseMeasurements:
        """Return an AnalyseMeasurements object with measurements of ``compound``."""
        measurement_ids = self.list_measurements_for_db_key("compound", compound)
        measurements = [self._m_manager.get_measurement(i) for i in measurement_ids]
        return AnalyseMeasurements(measurements)

    def _find_unregistered_metas(self, meta_suffix: str) -> list[Meta]:
        """Collect Meta objects from meta files whose ID is not in the DB yet."""
        # Resolve before changing cwd: walking a *relative* measurements_dir
        # after ``os.chdir`` into it would search a non-existent nested path.
        search_dir = os.path.abspath(self.measurements_dir)
        file_ending = meta_suffix + ".json"
        registered = self.entries.index
        unregistered = []
        cwd = os.getcwd()
        # NOTE(review): the search runs with cwd set to the measurements dir,
        # kept in case ``Meta.from_json`` relies on it — confirm.
        os.chdir(search_dir)
        try:
            for root, _, files in os.walk(search_dir):
                for file in files:
                    if not file.endswith(file_ending):
                        continue
                    meta_obj = Meta.from_json(os.path.join(root, file))
                    if meta_obj.measurement_id not in registered:
                        unregistered.append(meta_obj)
        finally:
            # Always restore the caller's cwd, even if a meta file fails to load.
            os.chdir(cwd)
        return unregistered

    def update_db_file(self, meta_suffix: str = FILE_SUFFIXES["meta"]) -> None:
        """Register measurements found on disk that are missing from the DB.

        Search all subdirectories of ``measurements_dir`` for meta data files
        (names ending in ``<meta_suffix>.json``), add the meta information of
        unregistered measurements, and write the updated DB to its CSV file.

        Parameters
        ----------
        meta_suffix :
            File-name suffix (before ``.json``) identifying meta data files.

        Raises
        ------
        ValueError
            If two unregistered meta files share the same measurement ID.
        """
        logger.debug(f"Searching data in {self.measurements_dir!r}...")
        meta_objs = self._find_unregistered_metas(meta_suffix)
        if not meta_objs:
            logger.info("No unregistered measurement found.")
            return None
        logger.debug(f"Found {len(meta_objs)} unregistered measurements.")
        for meta_obj in tqdm(meta_objs, desc="Adding measurements"):
            # Write the CSV file only once, after all measurements were added.
            self.add_measurement(meta_obj, to_file=False)
        logger.info(
            f"DB file '{self.db_file}' updated with {len(meta_objs)} measurements."
        )
        self.to_file()