File tree Expand file tree Collapse file tree 2 files changed +24
-4
lines changed Expand file tree Collapse file tree 2 files changed +24
-4
lines changed Original file line number Diff line number Diff line change @@ -51,13 +51,16 @@ RUN pip uninstall -y horovod && \
5151 /tmp/clean-layer.sh
5252{{ end }}
5353
54+ {{ if eq .Accelerator "gpu" }}
55+ # b/230864778: Temporarily swap the NVIDIA GPG key. Remove once new base image with new GPG key is released.
56+ RUN rm /etc/apt/sources.list.d/cuda.list && \
57+ apt-key del 7fa2af80 && \
58+ apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub
59+ {{ end }}
60+
5461# Use a fixed apt-get repo to stop intermittent failures due to flaky httpredir connections,
5562# as described by Lionel Chan at http://stackoverflow.com/a/37426929/5881346
5663RUN sed -i "s/httpredir.debian.org/debian.uchicago.edu/" /etc/apt/sources.list && \
57- # b/230864778: Temporarily swap the NVIDIA GPG key. Remove once new base image with new GPG key is released.
58- rm /etc/apt/sources.list.d/cuda.list && \
59- apt-key del 7fa2af80 && \
60- apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub && \
6164 apt-get update && \
6265 # Needed by lightGBM (GPU build)
6366 # https://lightgbm.readthedocs.io/en/latest/GPU-Tutorial.html#build-lightgbm
@@ -491,6 +494,7 @@ RUN pip install flashtext && \
491494 pip install bqplot && \
492495 pip install earthengine-api && \
493496 pip install transformers && \
497+ pip install datasets && \
494498 pip install dlib && \
495499 pip install kaggle-environments && \
496500 pip install geopandas && \
Original file line number Diff line number Diff line change 1+ import unittest
2+
3+ from datasets import Dataset
4+
5+
6+ class TestHuggingFaceDatasets (unittest .TestCase ):
7+
8+ def test_map (self ):
9+ def some_func (batch ):
10+ batch ['label' ] = 'foo'
11+ return batch
12+
13+ df = Dataset .from_dict ({'text' : ['Kaggle rocks!' ]})
14+ mapped_df = df .map (some_func )
15+
16+ self .assertEqual ('foo' , mapped_df [0 ]['label' ])
You can’t perform that action at this time.
0 commit comments