diff --git a/README.rst b/README.rst index c2ba41e..4622a7c 100644 --- a/README.rst +++ b/README.rst @@ -97,7 +97,7 @@ Variables of the methods are defined as follows * `Reference `_ object * BibTeX key of a reference - * ``None``, the default reference will be used or the first reference found + * ``None``, the newest reference will be returned Element properties ------------------ diff --git a/pyxray/base.py b/pyxray/base.py index 2d880e0..8599b41 100644 --- a/pyxray/base.py +++ b/pyxray/base.py @@ -55,38 +55,6 @@ class NotFound(Exception): class _DatabaseMixin(metaclass=abc.ABCMeta): - @abc.abstractmethod - def get_preferred_references(self): - """ - Return the BibTeX keys of the preferred references. - - If no reference is specified when calling a method, the value for the first preferred reference is returned. - If no preferred reference, the first value is returned. - - :return: preferred references - :rtype: :class:`tuple` - """ - raise NotImplementedError - - @abc.abstractmethod - @formatdoc(**_docextras) - def add_preferred_reference(self, reference): - """ - Adds a preferred reference. - - :arg reference: :class:`Reference` or its BibTeX key - - {exception} - """ - raise NotImplementedError - - @abc.abstractmethod - def clear_preferred_references(self): - """ - Clear all added preferred references. - """ - raise NotImplementedError - @abc.abstractmethod @formatdoc(**_docextras) def element(self, element): # pragma: no cover diff --git a/pyxray/data.py b/pyxray/data.py index 5042c82..a7e900f 100644 --- a/pyxray/data.py +++ b/pyxray/data.py @@ -3,9 +3,6 @@ """ __all__ = [ - "get_preferred_references", - "add_preferred_reference", - "clear_preferred_references", "element", "element_atomic_number", "element_symbol", @@ -48,15 +45,6 @@ class _EmptyDatabase(_DatabaseMixin): - def get_preferred_references(self): - return () - - def add_preferred_reference(self, reference): - pass - - def clear_preferred_references(self): - pass - def element(self, element): # pragma: no cover raise NotFound @@ -169,9 +157,6 @@ def _init_sql_database(): logger.error("No SQL database found") database = _EmptyDatabase() -get_preferred_references = database.get_preferred_references -add_preferred_reference = database.add_preferred_reference -clear_preferred_references = database.clear_preferred_references element = database.element element_atomic_number = database.element_atomic_number element_symbol = database.element_symbol diff --git a/pyxray/descriptor.py b/pyxray/descriptor.py index 0362af5..6255cd2 100644 --- a/pyxray/descriptor.py +++ b/pyxray/descriptor.py @@ -352,7 +352,7 @@ def __repr__(self): class Reference: bibtexkey: str author: str = None - year: str = None + year: int = None title: str = None type: str = None booktitle: str = None diff --git a/pyxray/sql/data.py b/pyxray/sql/data.py index 7c58462..08beaee 100644 --- a/pyxray/sql/data.py +++ b/pyxray/sql/data.py @@ -18,10 +18,12 @@ class StatementBuilder: - def __init__(self): + def __init__(self, distinct=False): + self._distinct = distinct self._columns = [] self._joins = {} self._clauses = [] + self._orderbys = [] def add_column(self, column): self._columns.append(column) @@ -34,9 +36,15 @@ def add_join(self, left, right, onclause): def add_clause(self, clause): self._clauses.append(clause) + def add_orderby(self, column, ascending=True): + self._orderbys.append((column, ascending)) + def build(self): statement = sqlalchemy.sql.select(self._columns) + if self._distinct: + statement = statement.distinct() + if self._joins: # Joins have to be nested to work with sqlalchemy # E.g. @@ -54,13 +62,22 @@ def build(self): statement = statement.select_from(finaljoin) - return statement.where(sqlalchemy.sql.and_(*self._clauses)) + if self._clauses: + statement = statement.where(sqlalchemy.sql.and_(*self._clauses)) + + if self._orderbys: + orderbys = [ + sqlalchemy.asc(column) if ascending else sqlalchemy.desc(column) + for column, ascending in self._orderbys + ] + statement = statement.order_by(*orderbys) + + return statement class SqlDatabase(_DatabaseMixin, SqlBase): def __init__(self, engine): super().__init__(engine) - self._preferred_references = [] def _expand_atomic_subshell(self, atomic_subshell): if ( @@ -280,10 +297,10 @@ def _update_reference(self, builder, table, reference, column="reference_id"): reference = reference.bibtexkey table_reference = self.require_table(descriptor.Reference) - builder.add_column(table_reference.c["bibtexkey"]) builder.add_join( table, table_reference, table.c[column] == table_reference.c["id"] ) + builder.add_orderby(table_reference.c["year"], ascending=False) # Newest first if reference: builder.add_clause(table_reference.c["bibtexkey"] == reference) @@ -308,38 +325,22 @@ def _update_notation(self, builder, table, notation): ) builder.add_clause(table_notation.c["key"] == notation) - def _execute(self, builder, remove_bibtexkey_column=True): + def _execute(self, builder): statement = builder.build() logger.debug(statement.compile()) # Execute with self.engine.connect() as conn: - rows = conn.execute(statement).fetchall() - if not rows: + row = conn.execute(statement).first() + if not row: raise NotFound - # Only one row, no need to check for the preferred references - elif len(rows) == 1: - row = rows[0] - - else: - dictrows = dict((dict(row).get("bibtexkey", None), row) for row in rows) - row = rows[0] # fall back - for bibtexkey in self._preferred_references: - if bibtexkey in dictrows: - row = dictrows[bibtexkey] - break - - row = dict(row) - if remove_bibtexkey_column: - row.pop("bibtexkey", None) - if len(row) == 1: - return row.popitem()[1] + return row[0] else: - return row.values() + return row - def _execute_many(self, builder, remove_bibtexkey_column=True): + def _execute_many(self, builder): statement = builder.build() logger.debug(statement.compile()) @@ -349,35 +350,7 @@ def _execute_many(self, builder, remove_bibtexkey_column=True): if not rows: raise NotFound - outrows = [] - for row in rows: - row = dict(row) - if remove_bibtexkey_column: - row.pop("bibtexkey", None) - outrows.append(list(row.values())) - - return outrows - - def get_preferred_references(self): - return tuple(self._preferred_references) - - def add_preferred_reference(self, reference): - if isinstance(reference, descriptor.Reference): - reference = reference.bibtexkey - if reference in self._preferred_references: - return - - table = self.require_table(descriptor.Reference) - - builder = StatementBuilder() - self._update_reference(builder, table, reference, "id") - - bibtexkey = self._execute(builder, remove_bibtexkey_column=False) - - self._preferred_references.append(bibtexkey) - - def clear_preferred_references(self): - self._preferred_references.clear() + return rows def element(self, element): table = self.require_table(descriptor.Element) @@ -443,7 +416,7 @@ def element_xray_transitions(self, element, xray_transition=None, reference=None table_xray = self.require_table(descriptor.XrayTransition) table_probability = self.require_table(property.XrayTransitionProbability) - builder = StatementBuilder() + builder = StatementBuilder(distinct=True) builder.add_column(table_xray.c["source_principal_quantum_number"]) builder.add_column(table_xray.c["source_azimuthal_quantum_number"]) builder.add_column(table_xray.c["source_total_angular_momentum_nominator"]) @@ -479,7 +452,7 @@ def element_xray_transitions(self, element, xray_transition=None, reference=None table_relative_weight = self.require_table( property.XrayTransitionRelativeWeight ) - builder = StatementBuilder() + builder = StatementBuilder(distinct=True) builder.add_column(table_xray.c["source_principal_quantum_number"]) builder.add_column(table_xray.c["source_azimuthal_quantum_number"]) builder.add_column(table_xray.c["source_total_angular_momentum_nominator"]) diff --git a/tests/sql/conftest.py b/tests/sql/conftest.py index bb6685d..5ac93a0 100644 --- a/tests/sql/conftest.py +++ b/tests/sql/conftest.py @@ -20,8 +20,8 @@ class MockParser(_Parser): def __iter__(self): - reference = descriptor.Reference("lee1966") - reference2 = descriptor.Reference("doe2016") + reference = descriptor.Reference("lee1966", year=1966) + reference2 = descriptor.Reference("doe2016", year=2016) element = descriptor.Element(118) atomic_shell = descriptor.AtomicShell(1) transition = descriptor.XrayTransition(L3, K) diff --git a/tests/sql/test_data.py b/tests/sql/test_data.py index 98c7e4c..ad5f118 100644 --- a/tests/sql/test_data.py +++ b/tests/sql/test_data.py @@ -2,10 +2,6 @@ """ """ # Standard library modules. -import os -import tempfile -import shutil -import sqlite3 # Third party modules. import pytest @@ -31,22 +27,6 @@ def database_real(tmp_path): return pyxray.data.database -def test_add_preferred_reference(database): - database.clear_preferred_references() - database.add_preferred_reference("lee1966") - - assert len(database.get_preferred_references()) == 1 - assert "lee1966" in database.get_preferred_references() - - database.clear_preferred_references() - assert len(database.get_preferred_references()) == 0 - - -def test_add_preferred_reference_not_found(database): - with pytest.raises(NotFound): - database.add_preferred_reference("foo") - - @pytest.mark.parametrize("element", [118, "Vi", "Vibranium"]) def test_element(database, element): assert database.element(element) == descriptor.Element(118) @@ -108,7 +88,8 @@ def test_element_name_notfound_wrong_reference(database): def test_element_atomic_weight_no_reference(database): - assert database.element_atomic_weight(118) == pytest.approx(999.1, abs=1e-2) + # doe2016 (111.1) takes precedence over lee1966 (999.1) because it is newer + assert database.element_atomic_weight(118) == pytest.approx(111.1, abs=1e-2) def test_element_atomic_weight_lee1966(database): @@ -123,18 +104,6 @@ def test_element_atomic_weight_doe2016(database): ) -def test_element_atomic_weight_preferred_reference(database): - database.clear_preferred_references() - database.add_preferred_reference("doe2016") - assert database.element_atomic_weight(118) == pytest.approx(111.1, abs=1e-2) - - database.clear_preferred_references() - database.add_preferred_reference("lee1966") - assert database.element_atomic_weight(118) == pytest.approx(999.1, abs=1e-2) - - database.clear_preferred_references() - - @pytest.mark.parametrize("element", [118, "Vi", "Vibranium"]) def test_element_mass_density_kg_per_m3(database, element): assert database.element_mass_density_kg_per_m3(element) == pytest.approx(