7373 python-version : ["3.8", "3.11"]
7474 fast-compile : [0,1]
7575 float32 : [0,1]
76- install-numba : [1]
76+ install-numba : [0]
77+ install-jax : [0]
7778 part :
7879 - " tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse"
7980 - " tests/scan"
9394 part : " tests/tensor/test_math.py"
9495 - fast-compile : 1
9596 float32 : 1
97+ include :
98+ - install-numba : 1
99+ python-version : " 3.8"
100+ fast-compile : 0
101+ float32 : 0
102+ part : " tests/link/numba"
103+ - install-numba : 1
104+ python-version : " 3.11"
105+ fast-compile : 0
106+ float32 : 0
107+ part : " tests/link/numba"
108+ - install-jax : 1
109+ python-version : " 3.8"
110+ fast-compile : 0
111+ float32 : 0
112+ part : " tests/link/jax"
113+ - install-jax : 1
114+ python-version : " 3.11"
115+ fast-compile : 0
116+ float32 : 0
117+ part : " tests/link/jax"
96118 steps :
97119 - uses : actions/checkout@v3
98120 with :
@@ -118,15 +140,20 @@ jobs:
118140 shell : bash -l {0}
119141 run : |
120142 mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark sympy
121- if [[ $INSTALL_NUMBA == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57" numba-scipy; fi
122- mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro
143+ # numba-scipy downgrades the installed scipy to 1.7.3 in Python 3.8, but
144+ # not numpy, even though scipy 1.7 requires numpy<1.23. When installing
145+ # PyTensor next, pip installs a lower version of numpy via the PyPI.
146+ if [[ $INSTALL_NUMBA == "1" ]] && [[ $PYTHON_VERSION == "3.8" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numpy<1.23" "numba>=0.57" numba-scipy; fi
147+ if [[ $INSTALL_NUMBA == "1" ]] && [[ $PYTHON_VERSION != "3.8" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57" numba-scipy; fi
148+ if [[ $INSTALL_JAX == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro; fi
123149 pip install -e ./
124150 mamba list && pip freeze
125151 python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
126- python -c 'import pytensor; assert( pytensor.config.blas__ldflags != "") '
152+ python -c 'import pytensor; assert pytensor.config.blas__ldflags != "", "Blas flags are empty" '
127153 env :
128154 PYTHON_VERSION : ${{ matrix.python-version }}
129155 INSTALL_NUMBA : ${{ matrix.install-numba }}
156+ INSTALL_JAX : ${{ matrix.install-jax }}
130157
131158 - name : Run tests
132159 shell : bash -l {0}
@@ -175,7 +202,7 @@ jobs:
175202 pip install -e ./
176203 mamba list && pip freeze
177204 python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
178- python -c 'import pytensor; assert( pytensor.config.blas__ldflags != "") '
205+ python -c 'import pytensor; assert pytensor.config.blas__ldflags != "", "Blas flags are empty" '
179206 env :
180207 PYTHON_VERSION : 3.9
181208 - name : Download previous benchmark data
0 commit comments