From 823bd698a075cb20c82b40daf533e39ebf4cf57c Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 5 Mar 2021 19:58:17 +0100 Subject: [PATCH 1/6] copy torchtext batch --- pytorch_lightning/utilities/apply_func.py | 31 ++---- pytorch_lightning/utilities/imports.py | 1 - .../utilities/torchtext_batch.py | 101 ++++++++++++++++++ requirements/extra.txt | 1 - requirements/test.txt | 1 + tests/helpers/imports.py | 8 ++ tests/models/test_gpu.py | 2 +- tests/utilities/test_apply_func_torchtext.py | 9 +- 8 files changed, 125 insertions(+), 29 deletions(-) create mode 100644 pytorch_lightning/utilities/torchtext_batch.py create mode 100644 tests/helpers/imports.py diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 0599cccec83be..a24780adc77f5 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -22,15 +22,7 @@ import torch from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _module_available, _TORCHTEXT_AVAILABLE - -if _TORCHTEXT_AVAILABLE: - if _module_available("torchtext.legacy.data"): - from torchtext.legacy.data import Batch - else: - from torchtext.data import Batch -else: - Batch = type(None) +from pytorch_lightning.utilities.torchtext_batch import Batch def to_dtype_tensor(value, dtype: torch.dtype = None, device: torch.device = None): @@ -142,22 +134,19 @@ def move_data_to_device(batch: Any, device: torch.device): """ def batch_to(data): - # try to move torchtext data first - if _TORCHTEXT_AVAILABLE and isinstance(data, Batch): - - # Shallow copy because each Batch has a reference to Dataset which contains all examples - device_data = copy(data) - for field, field_value in data.dataset.fields.items(): - if field_value is None: - continue - device_field = move_data_to_device(getattr(data, field), device) - setattr(device_data, field, device_field) - return device_data + # Shallow copy because each Batch has a reference to Dataset which contains all examples + device_data = copy(data) + for field, field_value in data.dataset.fields.items(): + if field_value is None: + continue + device_field = move_data_to_device(getattr(data, field), device) + setattr(device_data, field, device_field) + return device_data kwargs = dict(non_blocking=True) if isinstance(data, torch.Tensor) else {} return data.to(device, **kwargs) - dtype = (TransferableDataType, Batch) if _TORCHTEXT_AVAILABLE else TransferableDataType + dtype = (TransferableDataType, Batch) return apply_to_collection(batch, dtype=dtype, function=batch_to) diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 41a13d6c678a0..e2ef04b7e7b92 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -68,6 +68,5 @@ def _compare_version(package: str, op, version) -> bool: _OMEGACONF_AVAILABLE = _module_available("omegaconf") _RPC_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.rpc') _TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != 'none']) -_TORCHTEXT_AVAILABLE = _module_available("torchtext") _TORCHVISION_AVAILABLE = _module_available('torchvision') _XLA_AVAILABLE = _module_available("torch_xla") diff --git a/pytorch_lightning/utilities/torchtext_batch.py b/pytorch_lightning/utilities/torchtext_batch.py new file mode 100644 index 0000000000000..444d7f3afda48 --- /dev/null +++ b/pytorch_lightning/utilities/torchtext_batch.py @@ -0,0 +1,101 @@ +# THIS IS PURE COPY of TORCHTEXT BATCH CLASS +# THIS STRUCTURE SEEMS TO BE DEPRECATED IN TORCHTEXT + +import torch + + +class Batch(object): + """Defines a batch of examples along with its Fields. + + Attributes: + batch_size: Number of examples in the batch. + dataset: A reference to the dataset object the examples come from + (which itself contains the dataset's Field objects). + train: Deprecated: this attribute is left for backwards compatibility, + however it is UNUSED as of the merger with pytorch 0.4. + input_fields: The names of the fields that are used as input for the model + target_fields: The names of the fields that are used as targets during model training + + Also stores the Variable for each column in the batch as an attribute. + """ + + def __init__(self, data=None, dataset=None, device=None): + """Create a Batch from a list of examples.""" + if data is not None: + self.batch_size = len(data) + self.dataset = dataset + self.fields = dataset.fields.keys() # copy field names + self.input_fields = [k for k, v in dataset.fields.items() if v is not None and not v.is_target] + self.target_fields = [k for k, v in dataset.fields.items() if v is not None and v.is_target] + + for (name, field) in dataset.fields.items(): + if field is not None: + batch = [getattr(x, name) for x in data] + setattr(self, name, field.process(batch, device=device)) + + @classmethod + def fromvars(cls, dataset, batch_size, train=None, **kwargs): + """Create a Batch directly from a number of Variables.""" + batch = cls() + batch.batch_size = batch_size + batch.dataset = dataset + batch.fields = dataset.fields.keys() + for k, v in kwargs.items(): + setattr(batch, k, v) + return batch + + def __repr__(self): + return str(self) + + def __str__(self): + if not self.__dict__: + return 'Empty {} instance'.format(torch.typename(self)) + + fields_to_index = filter(lambda field: field is not None, self.fields) + var_strs = '\n'.join(['\t[.' + name + ']' + ":" + _short_str(getattr(self, name)) + for name in fields_to_index if hasattr(self, name)]) + + data_str = (' from {}'.format(self.dataset.name.upper()) + if hasattr(self.dataset, 'name') + and isinstance(self.dataset.name, str) else '') + + strt = '[{} of size {}{}]\n{}'.format(torch.typename(self), + self.batch_size, data_str, var_strs) + return '\n' + strt + + def __len__(self): + return self.batch_size + + def _get_field_values(self, fields): + if len(fields) == 0: + return None + elif len(fields) == 1: + return getattr(self, fields[0]) + else: + return tuple(getattr(self, f) for f in fields) + + def __iter__(self): + yield self._get_field_values(self.input_fields) + yield self._get_field_values(self.target_fields) + + +def _short_str(tensor): + # unwrap variable to tensor + if not torch.is_tensor(tensor): + # (1) unpack variable + if hasattr(tensor, 'data'): + tensor = getattr(tensor, 'data') + # (2) handle include_lengths + elif isinstance(tensor, tuple): + return str(tuple(_short_str(t) for t in tensor)) + # (3) fallback to default str + else: + return str(tensor) + + # copied from torch _tensor_str + size_str = 'x'.join(str(size) for size in tensor.size()) + device_str = '' if not tensor.is_cuda else \ + ' (GPU {})'.format(tensor.get_device()) + strt = '[{} of size {}{}]'.format(torch.typename(tensor), + size_str, device_str) + return strt diff --git a/requirements/extra.txt b/requirements/extra.txt index 85437327bce06..50cba404c88a7 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -3,7 +3,6 @@ matplotlib>3.1 horovod>=0.21.2 # no need to install with [pytorch] as pytorch is already installed omegaconf>=2.0.1 -torchtext>=0.5 onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 diff --git a/requirements/test.txt b/requirements/test.txt index 2d47143ca58d4..955e24606ddda 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -14,4 +14,5 @@ pre-commit>=1.0 cloudpickle>=1.3 nltk>=3.3 +torchtext>=0.5 pandas # needed in benchmarks diff --git a/tests/helpers/imports.py b/tests/helpers/imports.py new file mode 100644 index 0000000000000..e6c6fd18b605c --- /dev/null +++ b/tests/helpers/imports.py @@ -0,0 +1,8 @@ +import operator + +from pytorch_lightning.utilities.imports import _compare_version + +if _compare_version("torch", operator.ge, "0.9.0"): + from torchtext.legacy.data import Batch, Dataset, Example, Field, Iterator, LabelField # noqa: F401 +else: + from torchtext.data import Batch, Dataset, Example, Field, Iterator, LabelField # noqa: F401 diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py index 537dfd3e1006b..7764754594a09 100644 --- a/tests/models/test_gpu.py +++ b/tests/models/test_gpu.py @@ -16,7 +16,6 @@ import pytest import torch -from torchtext.data import Batch, Dataset, Example, Field, LabelField import tests.helpers.pipelines as tpipes import tests.helpers.utils as tutils @@ -25,6 +24,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel from tests.helpers.datamodules import ClassifDataModule +from tests.helpers.imports import Batch, Dataset, Example, Field, LabelField from tests.helpers.runif import RunIf from tests.helpers.simple_models import ClassificationModel diff --git a/tests/utilities/test_apply_func_torchtext.py b/tests/utilities/test_apply_func_torchtext.py index 4c9de8a1a8e25..d225c9523f0df 100644 --- a/tests/utilities/test_apply_func_torchtext.py +++ b/tests/utilities/test_apply_func_torchtext.py @@ -13,15 +13,14 @@ # limitations under the License. import pytest import torch -import torchtext -from torchtext.data.example import Example from pytorch_lightning.utilities.apply_func import move_data_to_device +from tests.helpers.imports import Batch, Dataset, Example, Field, Iterator from tests.helpers.runif import RunIf def _get_torchtext_data_iterator(include_lengths=False): - text_field = torchtext.data.Field( + text_field = Field( sequential=True, pad_first=False, # nosec init_token="", @@ -33,13 +32,13 @@ def _get_torchtext_data_iterator(include_lengths=False): example2 = Example.fromdict({"text": "b c a a"}, {"text": ("text", text_field)}) example3 = Example.fromdict({"text": "c b a"}, {"text": ("text", text_field)}) - dataset = torchtext.data.Dataset( + dataset = Dataset( [example1, example2, example3], {"text": text_field}, ) text_field.build_vocab(dataset) - iterator = torchtext.data.Iterator( + iterator = Iterator( dataset, batch_size=3, sort_key=None, From 7fb30460ea94e53bdd9b9315dc7f4bca15134db8 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 5 Mar 2021 20:13:02 +0100 Subject: [PATCH 2/6] update --- pytorch_lightning/utilities/apply_func.py | 33 ++++-- pytorch_lightning/utilities/imports.py | 1 + .../utilities/torchtext_batch.py | 101 ------------------ tests/utilities/test_apply_func_torchtext.py | 2 +- 4 files changed, 24 insertions(+), 113 deletions(-) delete mode 100644 pytorch_lightning/utilities/torchtext_batch.py diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index a24780adc77f5..16e0e5535278b 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import operator from abc import ABC from collections.abc import Mapping, Sequence from copy import copy @@ -22,7 +22,15 @@ import torch from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.torchtext_batch import Batch +from pytorch_lightning.utilities.imports import _TORCHTEXT_AVAILABLE, _compare_version + +if _TORCHTEXT_AVAILABLE: + if _compare_version("torch", operator.ge, "0.9.0"): + from torchtext.legacy.data import Batch + else: + from torchtext.data import Batch +else: + Batch = type(None) def to_dtype_tensor(value, dtype: torch.dtype = None, device: torch.device = None): @@ -134,19 +142,22 @@ def move_data_to_device(batch: Any, device: torch.device): """ def batch_to(data): - # Shallow copy because each Batch has a reference to Dataset which contains all examples - device_data = copy(data) - for field, field_value in data.dataset.fields.items(): - if field_value is None: - continue - device_field = move_data_to_device(getattr(data, field), device) - setattr(device_data, field, device_field) - return device_data + # try to move torchtext data first + if _TORCHTEXT_AVAILABLE and isinstance(data, Batch): + + # Shallow copy because each Batch has a reference to Dataset which contains all examples + device_data = copy(data) + for field, field_value in data.dataset.fields.items(): + if field_value is None: + continue + device_field = move_data_to_device(getattr(data, field), device) + setattr(device_data, field, device_field) + return device_data kwargs = dict(non_blocking=True) if isinstance(data, torch.Tensor) else {} return data.to(device, **kwargs) - dtype = (TransferableDataType, Batch) + dtype = (TransferableDataType, Batch) if _TORCHTEXT_AVAILABLE else TransferableDataType return apply_to_collection(batch, dtype=dtype, function=batch_to) diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index e2ef04b7e7b92..41a13d6c678a0 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -68,5 +68,6 @@ def _compare_version(package: str, op, version) -> bool: _OMEGACONF_AVAILABLE = _module_available("omegaconf") _RPC_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.rpc') _TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != 'none']) +_TORCHTEXT_AVAILABLE = _module_available("torchtext") _TORCHVISION_AVAILABLE = _module_available('torchvision') _XLA_AVAILABLE = _module_available("torch_xla") diff --git a/pytorch_lightning/utilities/torchtext_batch.py b/pytorch_lightning/utilities/torchtext_batch.py deleted file mode 100644 index 444d7f3afda48..0000000000000 --- a/pytorch_lightning/utilities/torchtext_batch.py +++ /dev/null @@ -1,101 +0,0 @@ -# THIS IS PURE COPY of TORCHTEXT BATCH CLASS -# THIS STRUCTURE SEEMS TO BE DEPRECATED IN TORCHTEXT - -import torch - - -class Batch(object): - """Defines a batch of examples along with its Fields. - - Attributes: - batch_size: Number of examples in the batch. - dataset: A reference to the dataset object the examples come from - (which itself contains the dataset's Field objects). - train: Deprecated: this attribute is left for backwards compatibility, - however it is UNUSED as of the merger with pytorch 0.4. - input_fields: The names of the fields that are used as input for the model - target_fields: The names of the fields that are used as targets during model training - - Also stores the Variable for each column in the batch as an attribute. - """ - - def __init__(self, data=None, dataset=None, device=None): - """Create a Batch from a list of examples.""" - if data is not None: - self.batch_size = len(data) - self.dataset = dataset - self.fields = dataset.fields.keys() # copy field names - self.input_fields = [k for k, v in dataset.fields.items() if v is not None and not v.is_target] - self.target_fields = [k for k, v in dataset.fields.items() if v is not None and v.is_target] - - for (name, field) in dataset.fields.items(): - if field is not None: - batch = [getattr(x, name) for x in data] - setattr(self, name, field.process(batch, device=device)) - - @classmethod - def fromvars(cls, dataset, batch_size, train=None, **kwargs): - """Create a Batch directly from a number of Variables.""" - batch = cls() - batch.batch_size = batch_size - batch.dataset = dataset - batch.fields = dataset.fields.keys() - for k, v in kwargs.items(): - setattr(batch, k, v) - return batch - - def __repr__(self): - return str(self) - - def __str__(self): - if not self.__dict__: - return 'Empty {} instance'.format(torch.typename(self)) - - fields_to_index = filter(lambda field: field is not None, self.fields) - var_strs = '\n'.join(['\t[.' + name + ']' + ":" + _short_str(getattr(self, name)) - for name in fields_to_index if hasattr(self, name)]) - - data_str = (' from {}'.format(self.dataset.name.upper()) - if hasattr(self.dataset, 'name') - and isinstance(self.dataset.name, str) else '') - - strt = '[{} of size {}{}]\n{}'.format(torch.typename(self), - self.batch_size, data_str, var_strs) - return '\n' + strt - - def __len__(self): - return self.batch_size - - def _get_field_values(self, fields): - if len(fields) == 0: - return None - elif len(fields) == 1: - return getattr(self, fields[0]) - else: - return tuple(getattr(self, f) for f in fields) - - def __iter__(self): - yield self._get_field_values(self.input_fields) - yield self._get_field_values(self.target_fields) - - -def _short_str(tensor): - # unwrap variable to tensor - if not torch.is_tensor(tensor): - # (1) unpack variable - if hasattr(tensor, 'data'): - tensor = getattr(tensor, 'data') - # (2) handle include_lengths - elif isinstance(tensor, tuple): - return str(tuple(_short_str(t) for t in tensor)) - # (3) fallback to default str - else: - return str(tensor) - - # copied from torch _tensor_str - size_str = 'x'.join(str(size) for size in tensor.size()) - device_str = '' if not tensor.is_cuda else \ - ' (GPU {})'.format(tensor.get_device()) - strt = '[{} of size {}{}]'.format(torch.typename(tensor), - size_str, device_str) - return strt diff --git a/tests/utilities/test_apply_func_torchtext.py b/tests/utilities/test_apply_func_torchtext.py index d225c9523f0df..fa5b822271533 100644 --- a/tests/utilities/test_apply_func_torchtext.py +++ b/tests/utilities/test_apply_func_torchtext.py @@ -15,7 +15,7 @@ import torch from pytorch_lightning.utilities.apply_func import move_data_to_device -from tests.helpers.imports import Batch, Dataset, Example, Field, Iterator +from tests.helpers.imports import Dataset, Example, Field, Iterator from tests.helpers.runif import RunIf From 3870459f03b558b8dfafde32c607af43326ada79 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 5 Mar 2021 20:13:58 +0100 Subject: [PATCH 3/6] rev --- requirements/extra.txt | 1 + requirements/test.txt | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/extra.txt b/requirements/extra.txt index 50cba404c88a7..85437327bce06 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -3,6 +3,7 @@ matplotlib>3.1 horovod>=0.21.2 # no need to install with [pytorch] as pytorch is already installed omegaconf>=2.0.1 +torchtext>=0.5 onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 diff --git a/requirements/test.txt b/requirements/test.txt index 955e24606ddda..2d47143ca58d4 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -14,5 +14,4 @@ pre-commit>=1.0 cloudpickle>=1.3 nltk>=3.3 -torchtext>=0.5 pandas # needed in benchmarks From 57ba543f8c8b37733152379f25cace2c7a404e76 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 5 Mar 2021 20:14:19 +0100 Subject: [PATCH 4/6] rev --- pytorch_lightning/utilities/apply_func.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 16e0e5535278b..4d2c5f6700024 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -22,7 +22,7 @@ import torch from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _TORCHTEXT_AVAILABLE, _compare_version +from pytorch_lightning.utilities.imports import _compare_version, _TORCHTEXT_AVAILABLE if _TORCHTEXT_AVAILABLE: if _compare_version("torch", operator.ge, "0.9.0"): From cbc192b1281de69da505a5da6ce5a9558b8d049d Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 5 Mar 2021 20:27:43 +0100 Subject: [PATCH 5/6] . --- pytorch_lightning/utilities/apply_func.py | 2 +- tests/helpers/imports.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 4d2c5f6700024..9fd42008b9d8d 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -25,7 +25,7 @@ from pytorch_lightning.utilities.imports import _compare_version, _TORCHTEXT_AVAILABLE if _TORCHTEXT_AVAILABLE: - if _compare_version("torch", operator.ge, "0.9.0"): + if _compare_version("torchtext", operator.ge, "0.9.0"): from torchtext.legacy.data import Batch else: from torchtext.data import Batch diff --git a/tests/helpers/imports.py b/tests/helpers/imports.py index e6c6fd18b605c..4db9c00d45eab 100644 --- a/tests/helpers/imports.py +++ b/tests/helpers/imports.py @@ -2,7 +2,7 @@ from pytorch_lightning.utilities.imports import _compare_version -if _compare_version("torch", operator.ge, "0.9.0"): +if _compare_version("torchtext", operator.ge, "0.9.0"): from torchtext.legacy.data import Batch, Dataset, Example, Field, Iterator, LabelField # noqa: F401 else: from torchtext.data import Batch, Dataset, Example, Field, Iterator, LabelField # noqa: F401 From 532be23a78c68775c7f7e14df075898a7793c722 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 5 Mar 2021 20:43:38 +0100 Subject: [PATCH 6/6] docs --- .github/workflows/docs-checks.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/docs-checks.yml b/.github/workflows/docs-checks.yml index 347f20196d974..5ee4f23b4b3cc 100644 --- a/.github/workflows/docs-checks.yml +++ b/.github/workflows/docs-checks.yml @@ -41,6 +41,8 @@ jobs: - name: Install dependencies run: | + python --version + pip --version # remove Horovod from requirements python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if not line.startswith('horovod')] ; open(fname, 'w').writelines(lines)" # python -m pip install --upgrade --user pip @@ -48,8 +50,6 @@ jobs: pip install --requirement requirements/extra.txt pip install --requirement requirements/loggers.txt pip install --requirement requirements/docs.txt - python --version - pip --version pip list shell: bash @@ -84,12 +84,12 @@ jobs: - name: Install dependencies run: | - pip install --requirement requirements.txt --upgrade-strategy only-if-needed --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --quiet + python --version + pip --version + # pip install --requirement requirements.txt --upgrade-strategy only-if-needed --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --quiet pip install --requirement requirements/docs.txt # install Texlive, see https://linuxconfig.org/how-to-install-latex-on-ubuntu-20-04-focal-fossa-linux sudo apt-get update && sudo apt-get install -y texlive-latex-extra dvipng texlive-pictures - python --version - pip --version pip list shell: bash