import datetime
import logging
import os
import shutil
from dataclasses import dataclass
import jinja2
import matplotlib.pyplot as plt
import pandas as pd
from . import plugin_loader, refinement_interface_factory, utils
from .config import ENCODING, MEASUREMENTS_DIR, PROCESSING_STATES, XRD_DATA_COLUMNS
from .meta import Meta
from .paths import MeasurementPaths
from .refinement import RefinedPhase, RefinementResult
from .refinement_interface import AppNotInstalledError
logger = logging.getLogger(__name__)
def get_data(
    file_path: str,
    col_angle: str = XRD_DATA_COLUMNS["angle"],
    col_intensity: str = XRD_DATA_COLUMNS["int_abs"],
    encoding: str = ENCODING,
) -> pd.Series:
    """Read one intensity column from an XRD data file.

    The file is parsed as CSV, with ``col_angle`` used as the index column.

    Args:
        file_path (str): Path of the CSV file holding the XRD data.
        col_angle (str): Name of the column that becomes the index.
        col_intensity (str): Name of the column that is returned.
        encoding (str): Encoding used to read the file.

    Returns:
        pd.Series: The ``col_intensity`` column of the file. Its index
        represents the 2θ angle in °.
    """
    # The existence check is delegated to the shared utils helper.
    utils.ensure_file_exists(file_path)
    read_options = {"index_col": col_angle, "encoding": encoding}
    measurement_table = pd.read_csv(file_path, **read_options)
    return measurement_table[col_intensity]
class NoRefinerSetError(Exception):
    """Raised when a measurement has no refinement interface plugin set.

    Args:
        message (str): Error message. It is passed to ``Exception`` and also
            stored on the ``message`` attribute.
    """

    def __init__(
        self,
        message="No refinement interface plugin defined for this measurement. Use `set_refinement_interface('name')` to set one.",
    ):
        super().__init__(message)
        self.message = message
@dataclass
class Measurement:
    """XRD measurement class.
    Parameter:
        :paths: MeasurementPaths object of the measurements.
        :meta:  Metadata object of the measurement.
        :data:  Pandas Series containing the *x*/*y* data of the measurement
                as imported. The index represents the 2θ angle.
    """
    paths: MeasurementPaths
    meta: Meta
    data: pd.Series = None
[docs]    @classmethod
    def from_id(cls, measurement_id: str, measurements_dir: str = MEASUREMENTS_DIR):
        """Alternative constructor to initiate a Measurement instance.
        Args:
            measurement_id (str): ID of the measurement to be loaded.
            measurements_dir (str): Path to the measurements directory.
        """
        paths = MeasurementPaths(measurements_dir, measurement_id)
        meta = Meta.from_json(paths.file_meta)
        data = get_data(paths.file_data)
        return Measurement(paths=paths, meta=meta, data=data) 
    def __post_init__(self):
        """Finish initialisation after the dataclass fields are assigned.

        Sets up the per-measurement logger, writes a first log entry when the
        measurement log file is still empty, and starts without a refinement
        interface plugin.
        """
        self._refinement_interface = None
        self._logger = self._get_measurement_logger()
        log_is_empty = os.path.getsize(self.paths.file_log) == 0
        if log_is_empty:
            # An empty log file means nothing was recorded for this entry yet.
            self._logger.info("Created measurement entry")
        logger.debug(f"Loaded measurement {self.meta.measurement_id!r}")
    def _get_measurement_logger(self):
        """Get logger to track measurement log."""
        # utils.make_dirs(self.paths.file_log)
        formatter = logging.Formatter("%(asctime)s : %(name)s : %(message)s")
        handler = logging.FileHandler(self.paths.file_log)
        handler.setFormatter(formatter)
        measurement_logger = logging.getLogger(self.meta.measurement_id)
        measurement_logger.setLevel(logging.INFO)
        measurement_logger.addHandler(handler)
        return measurement_logger
    def _check_data(self):
        """Check if XRD data are registered for the measurement.
        Raises:
            ValueError: If no XRD data are registered for the measurement.
        """
        if not self.has_data:
            raise ValueError(
                f"No XRD data registered for measurement {self.meta.measurement_id!r}."
            )
    def _data_to_csv(self, encoding: str = ENCODING) -> None:
        """Write the XRD data to the measurement's CSV data file.

        Args:
            encoding (str): Encoding used for the written file.

        Raises:
            ValueError: If no XRD data are registered for the measurement.
        """
        self._check_data()
        target_file = self.paths.file_data
        self.data.to_csv(target_file, encoding=encoding)
        logger.debug(f"XRD data written to: {target_file}")
    def _data_to_plot(self, ax: plt.Axes, norm=True) -> None:
        """Draw the XRD data on the given Axes and label its y-axis.

        Args:
            ax (plt.Axes): Axes the curve is added to.
            norm (bool): Plot ``data_norm`` if true, the raw ``data`` otherwise.

        Raises:
            ValueError: If no XRD data are registered for the measurement.
        """
        self._check_data()
        series = self.data_norm if norm else self.data
        ax.plot(series.index, series)
        # The series name doubles as the y-axis label.
        ax.set_ylabel(series.name)
    def _get_plot_window_title(self) -> str:
        """Returns string with measurement ID and sample if defined."""
        window_title = f"{self.meta.mode} measurement: {self.meta.measurement_id}"
        if self.meta.sample is not None:
            window_title += f" ({self.meta.sample})"
        return window_title
    # TODO: Add a flag that allows setting the operator as the author.
[docs]    def create_protocol(
        self,
        author: str,
        template: str,
        encoding: str = ENCODING,
    ) -> None:
        """
        Create a protocol document for a XRD measurement.
        Args:
            param author (str): Author of the measurement protocol.
            param template (str): Path to a protocol document template file.
            param encoding (str): Encoding of the protocol file.
        """
        # Load template
        loader = jinja2.FileSystemLoader(os.path.dirname(template))
        env = jinja2.Environment(loader=loader)
        template = env.get_template(os.path.basename(template))
        # Assign field values
        values = {
            "measurement_id": self.meta.measurement_id,
            "author": author,
            "date": datetime.date.today(),
        }
        if self.meta.sample is not None:
            values["sample"] = self.meta.sample
        # Create and write content
        notes_content = template.render(**values)
        with open(self.paths.file_protocol, "w", encoding=encoding) as fobj:
            fobj.write(notes_content)
            self._logger.info("Created measurement protocol") 
    @property
    def data_norm(self):
        """XRD data divided by their maximum intensity.

        Returns:
            pd.Series: New series with the normalised data of the
            measurement, named after the ``int_norm`` entry of
            ``XRD_DATA_COLUMNS``. The index represents the 2θ angle.

        Raises:
            ValueError: If no XRD data are registered for the measurement.
        """
        self._check_data()
        peak_intensity = self.data.max()
        normalised = self.data / peak_intensity
        normalised.name = XRD_DATA_COLUMNS["int_norm"]
        return normalised
[docs]    def get_cif_files(self, to_file=True) -> None:
        """Copy CIF file of refined phase(s) from refinement directory to results subdirectory."""
        try:
            input_cif = self._refinement_interface.get_cif_files()
        except ValueError as e:
            logger.debug(e)
            return None
        for phase, source_file in input_cif.items():
            if os.path.isfile(source_file) and to_file:
                destination_file = self.paths.get_cif_file_path(phase)
                utils.make_dirs(destination_file)
                # Copy CIF file only if it doesn't exists
                with open(source_file, "rb") as f1, open(destination_file, "rb") as f2:
                    if f1.read() == f2.read():
                        logger.debug(
                            f"No CIF file added for refined {phase!r} phase. The files exists already."
                        )
                    else:
                        shutil.copyfile(source_file, destination_file)
                        logger.info(
                            f"Added CIF file for refined {phase!r} phase: {os.path.basename(destination_file)!r}"
                        ) 
[docs]    def get_processing_state(self) -> str:
        """Get the measurements' data processing state.
        Returns:
            str: Current processing state, one of the following options:
              - ``refined`` if a refined data file exists.
              - ``None`` is no option listed above is applicable.
        """
        if self.is_refined:
            return PROCESSING_STATES["refined"]
        return None 
[docs]    def get_refined_data(self, encoding=ENCODING, to_file: bool = False) -> None:
        """Get the refinement results as pandas directory.
        The method requires a refinement interface plugin which provides the data.
        Args:
            encoding (str): Encoding used in refined data file if it gets written.
            to_file (bool): Refined data are written to the data directory if ``True``.
        """
        try:
            df = self._refinement_interface.get_refined_data(
                i_calc=XRD_DATA_COLUMNS["int_calc"],
                i_bg=XRD_DATA_COLUMNS["int_bg"],
            )
        except FileNotFoundError as e:
            logger.debug(e)
            return None
        if to_file:
            utils.make_dirs(self.paths.file_refined_data)
            df.to_csv(self.paths.file_refined_data, encoding=encoding)
            logger.info(
                f"Wrote refined XRD data to: {os.path.basename(self.paths.file_refined_data)!r}"
            )
            self.set_processing_state(to_file=to_file)
        return df 
[docs]    def get_refined_phase(self, phase: str) -> RefinedPhase:
        """Get a refined phase object.
        Requires the presence of a CIF file for the phase of interest.
        Args:
            phase (str): Name of the phase of interest
        Returns:
            RefinedPhase: Object containing the results for the specified
            refined phase.
        """
        file_path = self.paths.get_cif_file_path(phase)
        utils.ensure_file_exists(file_path=file_path)
        return RefinedPhase(file_path) 
[docs]    def get_refinement_result(self) -> RefinementResult:
        """Get a refinement result object.
        Returns:
            RefinementResult: Object containing the results of the refinement.
        """
        file_path = self.paths.file_refinement_result
        utils.ensure_file_exists(file_path=file_path)
        return RefinementResult.from_json(file_path) 
    @property
    def has_data(self) -> bool:
        """Check if XRD measurement data are available.
        Returns:
            bool: True if XRD data are available, and False if not.
        """
        if self.data is None:
            return False
        return True
    @property
    def has_refiner(self) -> str:
        """Checks if a refinement interface plugin is set for this measurement.
        The refinement plugin can be defined via the method 'set_refinement_interface'.
        Returns:
            bool: True if a refinement interface plugin in set, False if not.
        """
        if self._refinement_interface is None:
            return False
        return True
    @property
    def is_refined(self) -> str:
        """Flag that indicated wheter the measurement is refined.
        Returns:
            bool: True if refined data are present in the data subdirectory of the
                measurement, False otherwise.
        """
        return os.path.isfile(self.paths.file_refined_data)
[docs]    def plot(self, norm: bool = False, window_title: str = None):
        """Plot the XRD data.
        Args:
            norm (bool): Plot the data normalised to the maximum intensity if True.
            window_title (str): Title for the matplotlib window that will be created.
        Raises:
            ValueError: If no XRD data are registered for the measurement.
        """
        self._check_data()
        logger.debug(f"Plotting measurement {self.meta.measurement_id!r}...")
        if window_title is None:
            window_title = self._get_plot_window_title()
        fig_kwargs = {"tight_layout": True, "num": window_title}
        fig, ax = plt.subplots(**fig_kwargs)
        self._data_to_plot(ax, norm=norm)
        ax.set_xlabel(self.data.index.name)
        plt.show() 
[docs]    def refine(
        self,
        to_file: bool = True,
    ) -> None:
        """Refine the measurement with a refinement plugin.
        The refinement plugin has to be set via the method set_refinement_interface.
        Raises:
            NoRefinerSetError: If no refinement interface plugin is set for the
                measurement.
            AppNotInstalledError: If the refinement application is not installed
                on the machine.
        """
        if not self.has_refiner:
            raise NoRefinerSetError
        refinement_input_data = self._refinement_interface.file_refinement_input
        if not os.path.isfile(refinement_input_data):
            # Create refinement directory (use utils for uniform log)
            utils.make_dirs(os.path.join(refinement_input_data))
            self._refinement_interface.create_input_data()
            self._logger.info("Created refinement project")
        self._refinement_interface.open_refinement()
        self.get_refined_data(to_file=to_file)
        self.get_cif_files(to_file=to_file)
        refinement_results = self._refinement_interface.get_refinement_result()
        if to_file:
            refinement_results.to_json(self.paths.file_refinement_result)
            logger.info(
                f"Wrote refinement results to: {os.path.basename(self.paths.file_refinement_result)!r}"
            ) 
[docs]    def set_processing_state(self, state: str = None, to_file: bool = True) -> str:
        """Set the data processing state as metadata value.
        Parameter:
            :state:   Value for new processing state, besides expressions defined in
                      PROCESSING_STATES, the keyword 'reset' is accepted in order to
                      bypass a validity check and set the processing state to None.
                      If no state is provided, the state returned by the method
                      `get_processing_state` is added to the metadata.
            :to_file: Write the metadata to the JSON file if True and the new state
                      does not correspond to the initial state.
        """
        if state == "reset":
            state = None
        elif state == None:
            state = self.get_processing_state()
        elif state not in PROCESSING_STATES.values():
            raise ValueError(f"Unknown processing state: {state!r}.")
        self.update_meta("processing_state", state, to_file=to_file) 
[docs]    def set_refinement_interface(self, name: str = "profex", encoding=ENCODING) -> None:
        """Define a refinement interface plugin (default: "profex").
        The refinement interface module must be named with the filename:
        'refinement_<name>.py', and it has to be stored in the 'plugins'
        directory of this package.
        Raises:
            ValueError: If the refinement interface plugin is not registered.
        """
        plugin_loader.load_plugins("refinement")
        self._refinement_interface = refinement_interface_factory.create(
            arguments={
                "name": name,
                "measurement_id": self.meta.measurement_id,
                "data": self.data,
                "dir_refinement": self.paths.dir_refinement,
                "encoding": encoding,
            },
        )
        logger.debug(f"Set refinement interface plugin to {name!r}.")