Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ Variables of the methods are defined as follows

* `Reference <http://github.com/openmicroanalysis/pyxray/blob/master/pyxray/descriptor.py>`_ object
* BibTeX key of a reference
* ``None``, the default reference will be used or the first reference found
* ``None``, the newest reference will be returned

Element properties
------------------
Expand Down
32 changes: 0 additions & 32 deletions pyxray/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,38 +55,6 @@ class NotFound(Exception):


class _DatabaseMixin(metaclass=abc.ABCMeta):
@abc.abstractmethod
def get_preferred_references(self):
"""
Return the BibTeX keys of the preferred references.

If no reference is specified when calling a method, the value for the first preferred reference is returned.
If no preferred reference, the first value is returned.

:return: preferred references
:rtype: :class:`tuple`
"""
raise NotImplementedError

@abc.abstractmethod
@formatdoc(**_docextras)
def add_preferred_reference(self, reference):
"""
Adds a preferred reference.

:arg reference: :class:`Reference` or its BibTeX key

{exception}
"""
raise NotImplementedError

@abc.abstractmethod
def clear_preferred_references(self):
"""
Clear all added preferred references.
"""
raise NotImplementedError

@abc.abstractmethod
@formatdoc(**_docextras)
def element(self, element): # pragma: no cover
Expand Down
15 changes: 0 additions & 15 deletions pyxray/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
"""

__all__ = [
"get_preferred_references",
"add_preferred_reference",
"clear_preferred_references",
"element",
"element_atomic_number",
"element_symbol",
Expand Down Expand Up @@ -48,15 +45,6 @@


class _EmptyDatabase(_DatabaseMixin):
def get_preferred_references(self):
return ()

def add_preferred_reference(self, reference):
pass

def clear_preferred_references(self):
pass

def element(self, element): # pragma: no cover
raise NotFound

Expand Down Expand Up @@ -169,9 +157,6 @@ def _init_sql_database():
logger.error("No SQL database found")
database = _EmptyDatabase()

get_preferred_references = database.get_preferred_references
add_preferred_reference = database.add_preferred_reference
clear_preferred_references = database.clear_preferred_references
element = database.element
element_atomic_number = database.element_atomic_number
element_symbol = database.element_symbol
Expand Down
2 changes: 1 addition & 1 deletion pyxray/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def __repr__(self):
class Reference:
bibtexkey: str
author: str = None
year: str = None
year: int = None
title: str = None
type: str = None
booktitle: str = None
Expand Down
87 changes: 30 additions & 57 deletions pyxray/sql/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@


class StatementBuilder:
def __init__(self):
def __init__(self, distinct=False):
self._distinct = distinct
self._columns = []
self._joins = {}
self._clauses = []
self._orderbys = []

def add_column(self, column):
self._columns.append(column)
Expand All @@ -34,9 +36,15 @@ def add_join(self, left, right, onclause):
def add_clause(self, clause):
self._clauses.append(clause)

def add_orderby(self, column, ascending=True):
self._orderbys.append((column, ascending))

def build(self):
statement = sqlalchemy.sql.select(self._columns)

if self._distinct:
statement = statement.distinct()

if self._joins:
# Joins have to be nested to work with sqlalchemy
# E.g.
Expand All @@ -54,13 +62,22 @@ def build(self):

statement = statement.select_from(finaljoin)

return statement.where(sqlalchemy.sql.and_(*self._clauses))
if self._clauses:
statement = statement.where(sqlalchemy.sql.and_(*self._clauses))

if self._orderbys:
orderbys = [
sqlalchemy.asc(column) if ascending else sqlalchemy.desc(column)
for column, ascending in self._orderbys
]
statement = statement.order_by(*orderbys)

return statement


class SqlDatabase(_DatabaseMixin, SqlBase):
def __init__(self, engine):
super().__init__(engine)
self._preferred_references = []

def _expand_atomic_subshell(self, atomic_subshell):
if (
Expand Down Expand Up @@ -280,10 +297,10 @@ def _update_reference(self, builder, table, reference, column="reference_id"):
reference = reference.bibtexkey

table_reference = self.require_table(descriptor.Reference)
builder.add_column(table_reference.c["bibtexkey"])
builder.add_join(
table, table_reference, table.c[column] == table_reference.c["id"]
)
builder.add_orderby(table_reference.c["year"], ascending=False) # Newest first

if reference:
builder.add_clause(table_reference.c["bibtexkey"] == reference)
Expand All @@ -308,38 +325,22 @@ def _update_notation(self, builder, table, notation):
)
builder.add_clause(table_notation.c["key"] == notation)

def _execute(self, builder, remove_bibtexkey_column=True):
def _execute(self, builder):
statement = builder.build()
logger.debug(statement.compile())

# Execute
with self.engine.connect() as conn:
rows = conn.execute(statement).fetchall()
if not rows:
row = conn.execute(statement).first()
if not row:
raise NotFound

# Only one row, no need to check for the preferred references
elif len(rows) == 1:
row = rows[0]

else:
dictrows = dict((dict(row).get("bibtexkey", None), row) for row in rows)
row = rows[0] # fall back
for bibtexkey in self._preferred_references:
if bibtexkey in dictrows:
row = dictrows[bibtexkey]
break

row = dict(row)
if remove_bibtexkey_column:
row.pop("bibtexkey", None)

if len(row) == 1:
return row.popitem()[1]
return row[0]
else:
return row.values()
return row

def _execute_many(self, builder, remove_bibtexkey_column=True):
def _execute_many(self, builder):
statement = builder.build()
logger.debug(statement.compile())

Expand All @@ -349,35 +350,7 @@ def _execute_many(self, builder, remove_bibtexkey_column=True):
if not rows:
raise NotFound

outrows = []
for row in rows:
row = dict(row)
if remove_bibtexkey_column:
row.pop("bibtexkey", None)
outrows.append(list(row.values()))

return outrows

def get_preferred_references(self):
return tuple(self._preferred_references)

def add_preferred_reference(self, reference):
if isinstance(reference, descriptor.Reference):
reference = reference.bibtexkey
if reference in self._preferred_references:
return

table = self.require_table(descriptor.Reference)

builder = StatementBuilder()
self._update_reference(builder, table, reference, "id")

bibtexkey = self._execute(builder, remove_bibtexkey_column=False)

self._preferred_references.append(bibtexkey)

def clear_preferred_references(self):
self._preferred_references.clear()
return rows

def element(self, element):
table = self.require_table(descriptor.Element)
Expand Down Expand Up @@ -443,7 +416,7 @@ def element_xray_transitions(self, element, xray_transition=None, reference=None
table_xray = self.require_table(descriptor.XrayTransition)
table_probability = self.require_table(property.XrayTransitionProbability)

builder = StatementBuilder()
builder = StatementBuilder(distinct=True)
builder.add_column(table_xray.c["source_principal_quantum_number"])
builder.add_column(table_xray.c["source_azimuthal_quantum_number"])
builder.add_column(table_xray.c["source_total_angular_momentum_nominator"])
Expand Down Expand Up @@ -479,7 +452,7 @@ def element_xray_transitions(self, element, xray_transition=None, reference=None
table_relative_weight = self.require_table(
property.XrayTransitionRelativeWeight
)
builder = StatementBuilder()
builder = StatementBuilder(distinct=True)
builder.add_column(table_xray.c["source_principal_quantum_number"])
builder.add_column(table_xray.c["source_azimuthal_quantum_number"])
builder.add_column(table_xray.c["source_total_angular_momentum_nominator"])
Expand Down
4 changes: 2 additions & 2 deletions tests/sql/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

class MockParser(_Parser):
def __iter__(self):
reference = descriptor.Reference("lee1966")
reference2 = descriptor.Reference("doe2016")
reference = descriptor.Reference("lee1966", year=1966)
reference2 = descriptor.Reference("doe2016", year=2016)
element = descriptor.Element(118)
atomic_shell = descriptor.AtomicShell(1)
transition = descriptor.XrayTransition(L3, K)
Expand Down
35 changes: 2 additions & 33 deletions tests/sql/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,6 @@
""" """

# Standard library modules.
import os
import tempfile
import shutil
import sqlite3

# Third party modules.
import pytest
Expand All @@ -31,22 +27,6 @@ def database_real(tmp_path):
return pyxray.data.database


def test_add_preferred_reference(database):
database.clear_preferred_references()
database.add_preferred_reference("lee1966")

assert len(database.get_preferred_references()) == 1
assert "lee1966" in database.get_preferred_references()

database.clear_preferred_references()
assert len(database.get_preferred_references()) == 0


def test_add_preferred_reference_not_found(database):
with pytest.raises(NotFound):
database.add_preferred_reference("foo")


@pytest.mark.parametrize("element", [118, "Vi", "Vibranium"])
def test_element(database, element):
assert database.element(element) == descriptor.Element(118)
Expand Down Expand Up @@ -108,7 +88,8 @@ def test_element_name_notfound_wrong_reference(database):


def test_element_atomic_weight_no_reference(database):
assert database.element_atomic_weight(118) == pytest.approx(999.1, abs=1e-2)
# doe2016 (111.1) takes precedence over lee1966 (999.1) because it is newer
assert database.element_atomic_weight(118) == pytest.approx(111.1, abs=1e-2)


def test_element_atomic_weight_lee1966(database):
Expand All @@ -123,18 +104,6 @@ def test_element_atomic_weight_doe2016(database):
)


def test_element_atomic_weight_preferred_reference(database):
database.clear_preferred_references()
database.add_preferred_reference("doe2016")
assert database.element_atomic_weight(118) == pytest.approx(111.1, abs=1e-2)

database.clear_preferred_references()
database.add_preferred_reference("lee1966")
assert database.element_atomic_weight(118) == pytest.approx(999.1, abs=1e-2)

database.clear_preferred_references()


@pytest.mark.parametrize("element", [118, "Vi", "Vibranium"])
def test_element_mass_density_kg_per_m3(database, element):
assert database.element_mass_density_kg_per_m3(element) == pytest.approx(
Expand Down