@@ -99,12 +99,24 @@ RUN conda config --add channels nvidia && \
9999 mamba install -y mkl cartopy imagemagick pyproj "shapely<2" && \
100100 /tmp/clean-layer.sh
101101
102+ # Install spacy
103+ {{ if eq .Accelerator "gpu" }}
104+ RUN mamba install -y -c conda-forge spacy cupy cuda-version=$CUDA_MAJOR_VERSION.$CUDA_MINOR_VERSION && \
105+ /tmp/clean-layer.sh
106+ {{ else }}
107+ RUN pip install spacy && \
108+ /tmp/clean-layer.sh
109+ {{ end}}
102110{{ if eq .Accelerator "gpu" }}
103111
104112# b/232247930: uninstall pyarrow to avoid double installation with the GPU specific version.
105113RUN pip uninstall -y pyarrow && \
106114 mamba install -y cudf cuml && \
107115 /tmp/clean-layer.sh
116+
117+ # TODO: b/296444923 - Resolve pandas dependency another way
118+ RUN sed -i 's/^is_extension_type/# is_extension_type/g' /opt/conda/lib/python3.10/site-packages/cudf/api/types.py \
119+ && sed -i 's/^is_categorical/# is_categorical/g' /opt/conda/lib/python3.10/site-packages/cudf/api/types.py
108120{{ end }}
109121
110122# Install PyTorch
@@ -150,14 +162,6 @@ RUN pip install jax[cpu] && \
150162 /tmp/clean-layer.sh
151163{{ end }}
152164
153- # Install spacy
154- {{ if eq .Accelerator "gpu" }}
155- RUN mamba install -y -c conda-forge spacy cupy && \
156- /tmp/clean-layer.sh
157- {{ else }}
158- RUN pip install spacy && \
159- /tmp/clean-layer.sh
160- {{ end}}
161165
162166# Install GPU specific packages
163167{{ if eq .Accelerator "gpu" }}
@@ -177,12 +181,13 @@ RUN JAXVER=$(pip freeze | grep -e "^jax==") && \
177181 pandas \
178182 polars \
179183 flax \
180- "${JAXVER}" && \
184+ "${JAXVER}"
185+
186+ RUN apt-get install -y default-jre
181187
182- # Install h2o from source.
183- # Use `conda install -c h2oai h2o` once Python 3.7 version is released to conda.
184- apt-get install -y default-jre-headless && \
185- pip install -f https://h2o-release.s3.amazonaws.com/h2o/latest_stable_Py.html h2o \
188+ RUN pip install -f http://h2o-release.s3.amazonaws.com/h2o/latest_stable_Py.html h2o && /tmp/clean-layer.sh
189+
190+ RUN pip install \
186191 "tensorflow-gcs-config<=${TENSORFLOW_VERSION}" \
187192 "tensorflow==${TENSORFLOW_VERSION}" \
188193 tensorflow-addons \
@@ -248,7 +253,6 @@ RUN pip install scipy \
248253 datashader \
249254 # Boruta (python implementation)
250255 Boruta && \
251-
252256 apt-get install -y graphviz && pip install graphviz && \
253257 # Pandoc is a dependency of deap
254258 apt-get install -y pandoc && \
@@ -470,6 +474,7 @@ RUN pip install bleach \
470474 pyarrow \
471475 feather-format \
472476 fastai
477+
473478RUN python -m spacy download en_core_web_sm && python -m spacy download en_core_web_lg && \
474479 apt-get update && apt-get install -y ffmpeg && \
475480 /tmp/clean-layer.sh
@@ -678,4 +683,4 @@ RUN echo "$GIT_COMMIT" > /etc/git_commit && echo "$BUILD_DATE" > /etc/build_date
678683{{ if eq .Accelerator "gpu" }}
679684# Remove the CUDA stubs.
680685ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH_NO_STUBS"
681- {{ end }}
686+ {{ end }}
0 commit comments