diff --git a/src/elli/dispersions/base_dispersion.py b/src/elli/dispersions/base_dispersion.py index 99090435..de4c6a69 100644 --- a/src/elli/dispersions/base_dispersion.py +++ b/src/elli/dispersions/base_dispersion.py @@ -1,7 +1,7 @@ # Encoding: utf-8 """Abstract base class and utility classes for pyElli dispersion""" from abc import ABC, abstractmethod -from typing import Union +from typing import List, Union import numpy as np import numpy.typing as npt @@ -101,13 +101,32 @@ def add(self, *args, **kwargs) -> "Dispersion": return self - def __add__(self, other: Union[int, float, "Dispersion"]) -> "Dispersion": + def _check_valid_operand(self, other: Union[int, float, "Dispersion"]): + if not isinstance(other, (int, float, Dispersion)): + raise TypeError( + f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'" + ) + + def _is_non_std_dispersion(self, other: Union[int, float, "Dispersion"]) -> bool: + return isinstance(other, (IndexDispersion, dispersions.Table)) + + def __radd__(self, other: Union[int, float, "Dispersion"]) -> "DispersionSum": """Add up the dielectric function of multiple models""" - if isinstance(other, (int, float)): - return DispersionSum(self, dispersions.EpsilonInf(eps=other)) + return self.__add__(other) - if not isinstance(other, Dispersion): - raise TypeError(f"Invalid type {type(other)} added to dispersion") + def __add__(self, other: Union[int, float, "Dispersion"]) -> "DispersionSum": + """Add up the dielectric function of multiple models""" + self._check_valid_operand(other) + + if self._is_non_std_dispersion(other): + return other.__add__(self) + + if isinstance(other, DispersionSum): + other.dispersions.append(self) + return other + + if isinstance(other, (int, float)): + return DispersionSum(self, dispersions.EpsilonInf(other)) return DispersionSum(self, other) @@ -195,6 +214,36 @@ def _dict_to_str(dic): ) +class IndexDispersion(Dispersion): + """A dispersion based on a refractive index formulation.""" + + @abstractmethod + def refractive_index(self, lbda: npt.ArrayLike) -> npt.NDArray: + """Calculates the refractive index in a given wavelength window. + + Args: + lbda (npt.ArrayLike): The wavelength window with unit nm. + + Returns: + npt.NDArray: The refractive index for each wavelength point. + """ + + def __add__(self, other: Union[int, float, "Dispersion"]) -> "DispersionSum": + self._check_valid_operand(other) + + if isinstance(other, IndexDispersion): + raise NotImplementedError( + "Adding of index based dispersions is not supported yet" + ) + + raise TypeError( + "Cannot add refractive index and dielectric function based dispersions." + ) + + def dielectric_function(self, lbda: npt.ArrayLike) -> npt.NDArray: + return self.refractive_index(lbda) ** 2 + + class DispersionFactory: """A factory class for dispersion objects""" @@ -220,12 +269,30 @@ def get_dispersion(identifier: str, *args, **kwargs) -> Dispersion: class DispersionSum(Dispersion): """Represents a sum of two dispersions""" - single_params_template = {} - rep_params_template = {} + single_params_template: dict = {} + rep_params_template: dict = {} + dispersions: List[Dispersion] def __init__(self, *disps: Dispersion) -> None: super().__init__() - self.dispersions = disps + self.dispersions = list(disps) + + def __add__(self, other: Union[int, float, "Dispersion"]) -> "DispersionSum": + self._check_valid_operand(other) + + if self._is_non_std_dispersion(other): + return other.__add__(self) + + if isinstance(other, DispersionSum): + self.dispersions += other.dispersions + return self + + if isinstance(other, (int, float)): + self.dispersions.append(dispersions.EpsilonInf(eps=other)) + return self + + self.dispersions.append(other) + return self def dielectric_function(self, lbda: npt.ArrayLike) -> npt.NDArray: dielectric_function = sum( diff --git a/src/elli/dispersions/cauchy.py b/src/elli/dispersions/cauchy.py index 854b7368..0bea46c6 100644 --- a/src/elli/dispersions/cauchy.py +++ b/src/elli/dispersions/cauchy.py @@ -2,10 +2,10 @@ """Cauchy dispersion.""" import numpy.typing as npt -from .base_dispersion import Dispersion +from .base_dispersion import IndexDispersion -class Cauchy(Dispersion): +class Cauchy(IndexDispersion): r"""Cauchy dispersion. Single parameters: @@ -30,8 +30,8 @@ class Cauchy(Dispersion): single_params_template = {"n0": 1.5, "n1": 0, "n2": 0, "k0": 0, "k1": 0, "k2": 0} rep_params_template = {} - def dielectric_function(self, lbda: npt.ArrayLike) -> npt.NDArray: - refr_index = ( + def refractive_index(self, lbda: npt.ArrayLike) -> npt.NDArray: + return ( self.single_params.get("n0") + 1e2 * self.single_params.get("n1") / lbda**2 + 1e7 * self.single_params.get("n2") / lbda**4 @@ -42,4 +42,3 @@ def dielectric_function(self, lbda: npt.ArrayLike) -> npt.NDArray: + 1e7 * self.single_params.get("k2") / lbda**4 ) ) - return refr_index**2 diff --git a/src/elli/dispersions/cauchy_custom.py b/src/elli/dispersions/cauchy_custom.py index 0dd70708..0acaafcf 100644 --- a/src/elli/dispersions/cauchy_custom.py +++ b/src/elli/dispersions/cauchy_custom.py @@ -2,10 +2,10 @@ """Cauchy dispersion with custom exponents.""" import numpy.typing as npt -from .base_dispersion import Dispersion +from .base_dispersion import IndexDispersion -class CauchyCustomExponent(Dispersion): +class CauchyCustomExponent(IndexDispersion): r"""Cauchy dispersion with custom exponents. Single parameters: @@ -24,9 +24,7 @@ class CauchyCustomExponent(Dispersion): single_params_template = {"n0": 1.5} rep_params_template = {"f": 0, "e": 1} - def dielectric_function(self, lbda: npt.ArrayLike) -> npt.NDArray: - refr_index = self.single_params.get("n0") + sum( + def refractive_index(self, lbda: npt.ArrayLike) -> npt.NDArray: + return self.single_params.get("n0") + sum( c.get("f") * lbda ** c.get("e") for c in self.rep_params ) - - return refr_index**2 diff --git a/src/elli/dispersions/constant_refractive_index.py b/src/elli/dispersions/constant_refractive_index.py index 9ffb8ae9..0c3ac787 100644 --- a/src/elli/dispersions/constant_refractive_index.py +++ b/src/elli/dispersions/constant_refractive_index.py @@ -2,10 +2,10 @@ """Constant refractive index.""" import numpy.typing as npt -from .base_dispersion import Dispersion +from .base_dispersion import IndexDispersion -class ConstantRefractiveIndex(Dispersion): +class ConstantRefractiveIndex(IndexDispersion): r"""Constant refractive index. Single parameters: @@ -18,9 +18,8 @@ class ConstantRefractiveIndex(Dispersion): .. math:: \varepsilon(\lambda) = \boldsymbol{n}^2 """ - single_params_template = {"n": 1} rep_params_template = {} - def dielectric_function(self, _: npt.ArrayLike) -> npt.NDArray: - return self.single_params.get("n") ** 2 + def refractive_index(self, _: npt.ArrayLike) -> npt.NDArray: + return self.single_params.get("n") diff --git a/src/elli/dispersions/table_epsilon.py b/src/elli/dispersions/table_epsilon.py index 5f8ea62f..57afa6da 100644 --- a/src/elli/dispersions/table_epsilon.py +++ b/src/elli/dispersions/table_epsilon.py @@ -1,5 +1,6 @@ # Encoding: utf-8 """Dispersion specified by a table of wavelengths (nm) and dielectric function values.""" +from typing import Union import numpy as np import numpy.typing as npt import scipy.interpolate @@ -49,5 +50,8 @@ def __init__(self, *args, **kwargs) -> None: kind="cubic", ) + def __add__(self, _: Union[int, float, "Dispersion"]) -> "DispersionSum": + raise NotImplementedError("Adding of tabular dispersions is not yet supported") + def dielectric_function(self, lbda: npt.ArrayLike) -> npt.NDArray: return self.interpolation(lbda) diff --git a/src/elli/dispersions/table_index.py b/src/elli/dispersions/table_index.py index e1bc9436..e6281c79 100644 --- a/src/elli/dispersions/table_index.py +++ b/src/elli/dispersions/table_index.py @@ -4,10 +4,10 @@ import numpy.typing as npt import scipy.interpolate -from .base_dispersion import Dispersion, InvalidParameters +from .base_dispersion import IndexDispersion, InvalidParameters -class Table(Dispersion): +class Table(IndexDispersion): """Dispersion specified by a table of wavelengths (nm) and refractive index values. Please not that this model will produce errors for wavelengths outside the provided wavelength range. @@ -40,9 +40,9 @@ def __init__(self, *args, **kwargs) -> None: self.interpolation = scipy.interpolate.interp1d( self.single_params.get("lbda"), - self.single_params.get("n") ** 2, + self.single_params.get("n"), kind="cubic", ) - def dielectric_function(self, lbda: npt.ArrayLike) -> npt.NDArray: + def refractive_index(self, lbda: npt.ArrayLike) -> npt.NDArray: return self.interpolation(lbda) diff --git a/tests/test_dispersion_adding.py b/tests/test_dispersion_adding.py new file mode 100644 index 00000000..172f7346 --- /dev/null +++ b/tests/test_dispersion_adding.py @@ -0,0 +1,73 @@ +"""Test adding of dispersions""" +import pytest +from numpy.testing import assert_array_almost_equal +from elli import Cauchy, Sellmeier +from elli.dispersions.base_dispersion import DispersionSum +from elli.dispersions.table_epsilon import TableEpsilon + + +def test_fail_on_adding_index_dispersion(): + """Test whether adding for an index based model fails""" + cauchy_err_str = "Adding of index based dispersions is not supported yet" + with pytest.raises(NotImplementedError) as sum_err: + _ = Cauchy() + Cauchy() + + assert cauchy_err_str in str(sum_err.value) + + +def test_fail_on_adding_index_and_diel_dispersion(): + """Test whether the adding fails for an index based and dielectric dispersion""" + + for disp in [1, Sellmeier()]: + with pytest.raises(TypeError) as sum_err: + _ = disp + Cauchy() + + assert ( + "Cannot add refractive index and dielectric function based dispersions." + in str(sum_err.value) + ) + + +def test_adding_of_diel_dispersions(): + """Test if dielectric dispersions are added correctly""" + + dispersion_sum = Sellmeier() + Sellmeier() + + assert isinstance(dispersion_sum, DispersionSum) + assert len(dispersion_sum.dispersions) == 2 + + for disp in dispersion_sum.dispersions: + assert isinstance(disp, Sellmeier) + + assert_array_almost_equal( + dispersion_sum.get_dielectric_df().values, + 2 * Sellmeier().get_dielectric_df().values, + ) + + +def test_flat_dispersion_sum_on_multiple_add(): + """Test whether the DispersionSum stays flat on multiple adds""" + + dispersion_sum = Sellmeier() + Sellmeier() + Sellmeier() + + assert isinstance(dispersion_sum, DispersionSum) + assert len(dispersion_sum.dispersions) == 3 + + for disp in dispersion_sum.dispersions: + assert isinstance(disp, Sellmeier) + + assert_array_almost_equal( + dispersion_sum.get_dielectric_df().values, + 3 * Sellmeier().get_dielectric_df().values, + ) + + +def test_adding_of_tabular_dispersions(): + """Tests correct adding of tabular dispersions""" + + with pytest.raises(NotImplementedError) as not_impl_err: + _ = TableEpsilon() + 1 + + assert ( + str(not_impl_err.value) == "Adding of tabular dispersions is not yet supported" + )