Source code for xrd_tools.database

import inspect
import logging
import os
import time
from dataclasses import dataclass
from datetime import datetime

import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm

from .analyse_measurements import AnalyseMeasurements
from .config import DATABASE, FILE_SUFFIXES, MEASUREMENTS_DIR
from .measurement import Measurement
from .measurement_manager import MeasurementManager
from .meta import Meta

logger = logging.getLogger(__name__)


[docs]@dataclass(kw_only=True) class DBentry: "A XRD measurement database entry template." measurement_id: str """str: The ID of the measurement.""" sample: str """str: The ID of the sample that was measured.""" compound: str """str: The formula of the compound that was measured.""" description: str """str: A description of the measurement.""" comment: str """str: A comment about the measurement.""" ht_mode: bool """bool: Whether the measurement was taken in high temperature mode.""" processing_state: str = None """str: The processing state of the measurement, expression must be predefined.""" xrd_datetime: str = None """str: The date and time at which the measurement was taken.""" date_added: str """str: The date and time at when the measurement was added to the DB."""
@dataclass
class MeasurementDatabase:
    """XRD measurement database backed by a CSV file.

    On construction the existing database file (if any) is loaded, the
    measurements directory is scanned for meta data files of measurements
    that are not registered yet, and those are added to the database.

    Args:
        db_file (str): Path to the database file (CSV).
        measurements_dir (str): Path to measurements directory.
    """

    db_file: str = DATABASE
    measurements_dir: str = MEASUREMENTS_DIR

    def __post_init__(self):
        self._m_manager = MeasurementManager(self.measurements_dir)
        self.entries = self._load_data()
        self.update_db_file()

    def _load_data(self) -> pd.DataFrame:
        """Load the DB file, or return an empty frame if it does not exist.

        The returned frame is indexed by ``measurement_id`` and sorted by
        that index in descending order.
        """
        if not os.path.isfile(self.db_file):
            return pd.DataFrame()
        return pd.read_csv(
            self.db_file, index_col="measurement_id", dtype={"measurement_id": str}
        ).sort_index(ascending=False)

    def _validate_new_id(self, new_id: str) -> None:
        """Ensure new_id does not exist in DB file.

        Raises:
            ValueError: If ``new_id`` is already part of the database index.
        """
        if new_id in self.entries.index:
            raise ValueError(f"Measurement ID {new_id!r} already existing in DB.")

    def add_measurement(self, meta_obj: Meta, to_file: bool = True) -> None:
        """Add a new measurement to the database.

        Only those attributes of ``meta_obj`` whose names match a field of
        :class:`DBentry` are copied; ``date_added`` is always set to the
        current local time (ISO format).

        Args:
            meta_obj: Meta object of measurement to be added to database.
            to_file: Flag indicating whether the updated DB is written to
                its `csv` file.

        Raises:
            ValueError: If the measurement ID is already registered.
        """
        entry_fields = list(inspect.signature(DBentry).parameters)
        meta_values = meta_obj.__dict__
        kwargs = {k: meta_values[k] for k in entry_fields if k in meta_values}
        kwargs["date_added"] = datetime.now().isoformat()

        if self.empty:
            # Nothing to collide with; the new row defines the columns.
            self.entries = pd.DataFrame([kwargs]).set_index("measurement_id")
        else:
            self._validate_new_id(kwargs["measurement_id"])
            new_row = pd.DataFrame([kwargs]).set_index("measurement_id")
            self.entries = pd.concat([self.entries, new_row])

        if to_file:
            self.to_file()

    def to_file(self) -> None:
        """Write the database content to its CSV file."""
        self.entries.to_csv(self.db_file)
        logger.debug(f"Database written to {self.db_file!r}")

    @property
    def empty(self) -> bool:
        """Returns True if no measurements are registered."""
        return len(self.entries) == 0

    def list_measurements(self) -> str:
        """Return string with table of database entries."""
        if self.empty:
            return "No measurements registered in database"
        return str(self.entries)

    def list_measurements_for_db_key(
        self, column: str, value: str | list[str]
    ) -> list[str]:
        """Returns list of measurement IDs for provided DB column and value(s).

        Method can be used to filter the database for a specific data subset,
        e.g. measurements corresponding to a certain processing state. The
        result is ordered by sample and then by measurement ID.

        Args:
            column: Column where value(s) exists (e.g. `processing_state`).
            value: Value to filter for (e.g. `refined`), or a list/tuple/set
                of values of which any may match.

        Returns:
            The matching measurement IDs; an empty list if the database has
            no entries.

        Raises:
            KeyError: If ``column`` is not a column of a non-empty database.
        """
        # An empty DB has no columns at all, so sorting by "sample" would fail.
        if self.empty:
            return []
        df = (
            self.entries.reset_index()
            .sort_values(["sample", "measurement_id"])
            .set_index("measurement_id")
        )
        if isinstance(value, (list, tuple, set, frozenset)):
            mask = df[column].isin(list(value))
        else:
            mask = df[column] == value
        return df[mask].index.tolist()

    def list_compounds(self) -> list[str]:
        """Returns a list with unique compounds registered in the DB (sorted)."""
        if self.empty:
            return []
        return list(self.entries["compound"].dropna().sort_values().unique())

    def list_samples(self) -> list[str]:
        """Returns a list with sample IDs registered in the DB (sorted)."""
        if self.empty:
            return []
        return list(self.entries["sample"].dropna().sort_values().unique())

    def get_measurement_for_id(self, measurement_id: str) -> Measurement:
        """Returns a measurement object for the provided measurement ID."""
        return self._m_manager.get_measurement(measurement_id)

    def get_measurement_for_sample(self, sample: str) -> Measurement:
        """Returns a measurement object for the provided sample ID.

        If no database entry exists for ``sample``, the argument is passed to
        the measurement manager as if it were a measurement ID.

        Raises:
            IndexError: If neither a DB entry nor a measurement is found.
            ValueError: If more than one measurement is registered for the
                sample.
        """
        measurement_ids = self.list_measurements_for_db_key("sample", sample)
        if len(measurement_ids) == 1:
            return self._m_manager.get_measurement(measurement_ids[0])
        if len(measurement_ids) == 0:
            try:
                return self._m_manager.get_measurement(sample)
            except IndexError as err:
                raise IndexError(
                    f"No measurement registered for sample '{sample}'."
                ) from err
        raise ValueError(
            f"Multiple measurements registered for '{sample}'."
            + f"\n Provide one of the measurement IDs instead:"
            + f"\n {measurement_ids}"
        )

    def get_measurements_for_compound(self, compound: str) -> AnalyseMeasurements:
        """
        Returns an AnalyseMeasurements object with measurement of the provided
        compound.
        """
        measurement_ids = self.list_measurements_for_db_key("compound", compound)
        measurements = [self._m_manager.get_measurement(i) for i in measurement_ids]
        return AnalyseMeasurements(measurements)

    def update_db_file(self, meta_suffix=FILE_SUFFIXES["meta"]) -> None:
        """
        Search for all meta data files in measurements subdirectories, and
        update database file with the meta information of unregistered
        measurements.

        Args:
            meta_suffix: File name suffix (without the ``.json`` extension)
                that identifies a meta data file.
        """
        # Resolve before changing directory: a relative measurements_dir would
        # otherwise be looked up relative to itself once the cwd has changed.
        search_dir = os.path.abspath(self.measurements_dir)
        logger.debug(f"Searching data in {self.measurements_dir!r}...")

        meta_objs = []
        cwd = os.getcwd()
        # NOTE(review): the chdir is kept from the original implementation;
        # presumably Meta.from_json relies on it - TODO confirm.
        os.chdir(search_dir)
        try:
            for root, _, files in os.walk(search_dir):
                for file in files:
                    if not file.endswith(meta_suffix + ".json"):
                        continue
                    meta_obj = Meta.from_json(os.path.join(root, file))
                    if meta_obj.measurement_id not in self.entries.index:
                        meta_objs.append(meta_obj)
        finally:
            # Restore the caller's working directory even if parsing fails.
            os.chdir(cwd)

        if len(meta_objs) == 0:
            logger.info("No unregistered measurement found.")
            return None

        logger.debug(f"Found {len(meta_objs)} unregistered measurements.")
        count = 0
        for meta_obj in tqdm(meta_objs, desc="Adding measurements"):
            # Write once after the loop instead of once per measurement.
            self.add_measurement(meta_obj, to_file=False)
            count += 1

        self.to_file()
        logger.info(f"DB file '{self.db_file}' updated with {count} measurements.")